Skip to content

Commit

Permalink
Merge branch 'new-pipeline' of https://www.github.com/ivogabe/accelerate
Browse files Browse the repository at this point in the history
 into new-pipeline-david
  • Loading branch information
dpvanbalen committed Jun 7, 2024
2 parents 7022acd + 00c91e0 commit 3eb258e
Show file tree
Hide file tree
Showing 4 changed files with 192 additions and 155 deletions.
3 changes: 2 additions & 1 deletion src/Data/Array/Accelerate/Pretty/Exp.hs
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,8 @@ prettyPreOpenExp ctx prettyArrayInstr env exp =
--
single = parensIf (needsParens ctx (Operator "?:" Infix N 0))
$ sep [ p', pretty '?', t', pretty ':', e' ]
multi = hang 3
multi = parensIf (ctxPrecedence ctx > 0)
$ hang 3
$ vsep [ if_ <+> p'
, hang shiftwidth (sep [ then_, t' ])
, hang shiftwidth (sep [ else_, e' ]) ]
Expand Down
166 changes: 33 additions & 133 deletions src/Data/Array/Accelerate/Trafo/Desugar.hs
Original file line number Diff line number Diff line change
Expand Up @@ -207,23 +207,7 @@ class NFData' op => DesugarAcc (op :: Type -> Type) where
-> Arg env (In (sh, Int) e)
-> Arg env (Out sh e)
-> OperationAcc op env ()
mkFold f def input@(ArgArray _ repr@(ArrayR shr tp) _ _) output
-- Binding of the result of the first step
| DeclareVars lhs1 k1 value1 <- declareVars $ desugarArrayR repr
-- Binding in the iteration and condition function for the first step
, DeclareVars lhsSh kSh valueSh <- declareVars $ mapTupR GroundRscalar $ shapeType shr
, DeclareVars lhsBf kBf valueBf <- declareVars $ buffersR tp
= let
lhs2 = LeftHandSidePair lhsSh lhsBf
argTmp = ArgArray In (ArrayR shr tp) (valueSh kBf) (valueBf weakenId)
c = Alam lhs2 $ Abody $ case valueSh kBf of
TupRpair _ ix -> Compute $ mkBinary (PrimLtEq singleType) (paramsIn (TupRsingle scalarTypeInt) ix) (mkConstant (TupRsingle scalarTypeInt) 1)
_ -> error "Impossible pair"
g = Alam lhs2 $ Abody $ mkDefaultFoldStep2 (weaken (kBf .> kSh .> k1) f) argTmp (weaken (kBf .> kSh .> k1) output)
in
alet lhs1 (mkDefaultFoldStep1 f def input output)
$ alet (LeftHandSideWildcard $ desugarArrayR repr) (Awhile (shared $ desugarArrayR repr) c g $ value1 weakenId)
$ Return TupRunit
mkFold f def input@(ArgArray _ repr@(ArrayR shr tp) _ _) output = mkDefaultFoldSequential f def input output

Check warning on line 210 in src/Data/Array/Accelerate/Trafo/Desugar.hs

View workflow job for this annotation

GitHub Actions / stack | ubuntu-latest-x64

Defined but not used: ‘repr’

Check warning on line 210 in src/Data/Array/Accelerate/Trafo/Desugar.hs

View workflow job for this annotation

GitHub Actions / stack | ubuntu-latest-x64

Defined but not used: ‘shr’

Check warning on line 210 in src/Data/Array/Accelerate/Trafo/Desugar.hs

View workflow job for this annotation

GitHub Actions / stack | ubuntu-latest-x64

Defined but not used: ‘tp’

Check warning on line 210 in src/Data/Array/Accelerate/Trafo/Desugar.hs

View workflow job for this annotation

GitHub Actions / stack | windows-latest-x64

Defined but not used: `repr'

Check warning on line 210 in src/Data/Array/Accelerate/Trafo/Desugar.hs

View workflow job for this annotation

GitHub Actions / stack | windows-latest-x64

Defined but not used: `shr'

Check warning on line 210 in src/Data/Array/Accelerate/Trafo/Desugar.hs

View workflow job for this annotation

GitHub Actions / stack | windows-latest-x64

Defined but not used: `tp'

Check warning on line 210 in src/Data/Array/Accelerate/Trafo/Desugar.hs

View workflow job for this annotation

GitHub Actions / cabal | windows-latest-x64 ghc-9.4 release

Defined but not used: ‘repr’

Check warning on line 210 in src/Data/Array/Accelerate/Trafo/Desugar.hs

View workflow job for this annotation

GitHub Actions / cabal | windows-latest-x64 ghc-9.4 release

Defined but not used: ‘shr’

Check warning on line 210 in src/Data/Array/Accelerate/Trafo/Desugar.hs

View workflow job for this annotation

GitHub Actions / cabal | windows-latest-x64 ghc-9.4 release

Defined but not used: ‘tp’

Check warning on line 210 in src/Data/Array/Accelerate/Trafo/Desugar.hs

View workflow job for this annotation

GitHub Actions / stack | macOS-latest-x64

Defined but not used: ‘repr’

Check warning on line 210 in src/Data/Array/Accelerate/Trafo/Desugar.hs

View workflow job for this annotation

GitHub Actions / stack | macOS-latest-x64

Defined but not used: ‘shr’

Check warning on line 210 in src/Data/Array/Accelerate/Trafo/Desugar.hs

View workflow job for this annotation

GitHub Actions / stack | macOS-latest-x64

Defined but not used: ‘tp’

Check warning on line 210 in src/Data/Array/Accelerate/Trafo/Desugar.hs

View workflow job for this annotation

GitHub Actions / cabal | macOS-latest-x64 ghc-9.4 release

Defined but not used: ‘repr’

Check warning on line 210 in src/Data/Array/Accelerate/Trafo/Desugar.hs

View workflow job for this annotation

GitHub Actions / cabal | macOS-latest-x64 ghc-9.4 release

Defined but not used: ‘shr’

Check warning on line 210 in src/Data/Array/Accelerate/Trafo/Desugar.hs

View workflow job for this annotation

GitHub Actions / cabal | macOS-latest-x64 ghc-9.4 release

Defined but not used: ‘tp’

mkFoldSeg :: IntegralType i
-> Arg env (Fun' (e -> e -> e))
Expand Down Expand Up @@ -1178,125 +1162,41 @@ mkIntersect shr x y
mkIntersect' (ShapeRsnoc shr) (TupRpair s1 x1) (TupRpair s2 x2) = mkIntersect' shr s1 s2 `Pair` mkBinary (PrimMin singleType) (paramsIn' x1) (paramsIn' x2)

Check warning on line 1162 in src/Data/Array/Accelerate/Trafo/Desugar.hs

View workflow job for this annotation

GitHub Actions / stack | ubuntu-latest-x64

This binding for ‘shr’ shadows the existing binding

Check warning on line 1162 in src/Data/Array/Accelerate/Trafo/Desugar.hs

View workflow job for this annotation

GitHub Actions / stack | windows-latest-x64

This binding for `shr' shadows the existing binding

Check warning on line 1162 in src/Data/Array/Accelerate/Trafo/Desugar.hs

View workflow job for this annotation

GitHub Actions / cabal | windows-latest-x64 ghc-9.4 release

This binding for ‘shr’ shadows the existing binding

Check warning on line 1162 in src/Data/Array/Accelerate/Trafo/Desugar.hs

View workflow job for this annotation

GitHub Actions / stack | macOS-latest-x64

This binding for ‘shr’ shadows the existing binding

Check warning on line 1162 in src/Data/Array/Accelerate/Trafo/Desugar.hs

View workflow job for this annotation

GitHub Actions / cabal | macOS-latest-x64 ghc-9.4 release

This binding for ‘shr’ shadows the existing binding
mkIntersect' (ShapeRsnoc _) _ _ = error "Impossible pair"

-- Default implementation for the first step of a fold.
-- The output of the inner dimension is guaranteed to
-- be a power of two.
mkDefaultFoldStep1 :: forall benv op sh e. DesugarAcc op => Arg benv (Fun' (e -> e -> e)) -> Maybe (Arg benv (Exp' e)) -> Arg benv (In (sh, Int) e) -> Arg benv (Out sh e) -> OperationAcc op benv (DesugaredArrays (Array (sh, Int) e))
mkDefaultFoldStep1 (ArgFun f) def argIn@(ArgArray _ (ArrayR shr tp) (sh `TupRpair` n) _) argOut
| DeclareVars lhsTmp kTmp valueTmp <- declareVars $ buffersR tp
= let
lhsN1 = LeftHandSideSingle $ GroundRscalar scalarTypeInt
kN1 = weakenSucc weakenId

-- n-1 in case of fold1.
-- For fold, we have one additional element (the default value), and thus have (n+1)-1
nMinus1
| Just _ <- def = paramsIn (TupRsingle scalarType) n
| otherwise = mkBinary (PrimSub numType) (paramsIn (TupRsingle scalarType) n) (mkConstant (TupRsingle scalarTypeInt) 1)

shBase' = weakenVars (weakenSucc weakenId) sh
shTmp' = shBase' `TupRpair` TupRsingle (Var (GroundRscalar scalarTypeInt) ZeroIdx)

shTmp = weakenVars kTmp shTmp'
tmp = valueTmp weakenId
k = weakenSucc kTmp
argG = ArgFun $ mkDefaultFoldStep1Function (weakenArrayInstr k f) (weaken k <$> def) (weaken kTmp $ Var (GroundRscalar scalarTypeInt) ZeroIdx) (weaken k argIn)
argTmp = ArgArray Out (ArrayR shr tp) shTmp tmp
in
alet lhsN1 (Compute $ mkBinary (PrimBShiftL integralType) (mkConstant (TupRsingle scalarTypeInt) 1) $ mkLog2 nMinus1)
$ aletUnique lhsTmp (mkDefaultFoldAllocOrOutput (weaken kN1 argOut) $ groundToExpVar (shapeType shr) shTmp')
$ alet (LeftHandSideWildcard TupRunit) (mkGenerate argG argTmp)
$ Return (shTmp `TupRpair` tmp)

-- log_2(x) = 63 − clz(x) (for 64-bit integers)
mkLog2 :: OpenExp env benv Int -> OpenExp env benv Int
mkLog2 = mkBinary (PrimSub numType) (mkConstant (TupRsingle scalarTypeInt) 63) . PrimApp (PrimCountLeadingZeros integralType)

mkDefaultFoldStep1Function :: forall benv sh e. HasCallStack => Fun benv (e -> e -> e) -> Maybe (Arg benv (Exp' e)) -> GroundVar benv Int -> Arg benv (In (sh, Int) e) -> Fun benv ((sh, Int) -> e)
mkDefaultFoldStep1Function f def n1' argIn@(ArgArray _ (ArrayR (ShapeRsnoc shr) tp) (sh' `TupRpair` n') buffers)
| DeclareVars lhsX kX valueX <- declareVars $ shapeType shr
= let
x = expVars $ valueX (weakenSucc weakenId)
y = Evar $ Var scalarTypeInt ZeroIdx
y2 = mkBinary (PrimMul numType) y (mkConstant (TupRsingle scalarTypeInt) 2)
n1 = paramIn scalarTypeInt n1'

-- When a default or initial value is given, we just pretend that the input array is one
-- element larger, i.e., prefixed with that default value.
n
| Just _ <- def = mkBinary (PrimAdd numType) (paramsIn (TupRsingle scalarTypeInt) n') (mkConstant (TupRsingle scalarTypeInt) 1)
| otherwise = paramsIn (TupRsingle scalarTypeInt) n'

index' :: OpenExp env' benv sh -> OpenExp env' benv Int -> OpenExp env' benv e
index' x' y'
| Just (ArgExp d) <- def =
Cond (mkBinary (PrimEq singleType) y' $ mkConstant (TupRsingle scalarTypeInt) 0)
(weakenE weakenEmpty d)
(index (ArrayR (ShapeRsnoc shr) tp) (sh' `TupRpair` n') buffers $ Pair x' $ mkBinary (PrimSub numType) y' $ mkConstant (TupRsingle scalarTypeInt) 1)
index' x' y' =
(index (ArrayR (ShapeRsnoc shr) tp) (sh' `TupRpair` n') buffers $ Pair x' y')
in
-- \(x, y) ->
Lam (lhsX `LeftHandSidePair` LeftHandSideSingle scalarTypeInt)
$ Body
-- if (y < n-n1)
$ Cond (mkBinary (PrimLt singleType) y $ mkBinary (PrimSub numType) n n1)
-- then {reduce y*2 and y*2+1}
(apply2 tp f (index' x y2) (index' x $ mkBinary (PrimAdd numType) y2 $ mkConstant (TupRsingle scalarTypeInt) 1))
-- else {just copy the value from index (y+n-n1)}
(index' x $ mkBinary (PrimAdd numType) y $ mkBinary (PrimSub numType) n n1)

-- Halves the inner dimension of the array
mkDefaultFoldStep2 :: DesugarAcc op => Arg benv (Fun' (e -> e -> e)) -> Arg benv (In (sh, Int) e) -> Arg benv (Out sh e) -> OperationAcc op benv (DesugaredArrays (Array (sh, Int) e))
mkDefaultFoldStep2 (ArgFun f) argIn@(ArgArray _ (ArrayR shr@(ShapeRsnoc shr') tp) (TupRpair sh n) input) argOut
| DeclareVars lhsSh kSh valueSh <- declareVars $ TupRsingle $ GroundRscalar scalarTypeInt
, DeclareVars lhsTmp kTmp valueTmp <- declareVars $ buffersR tp
= let
shBase' = weakenVars kSh sh
shTmp' = shBase' `TupRpair` valueSh weakenId
mkDefaultFoldSequential :: forall benv op sh e. DesugarAcc op => Arg benv (Fun' (e -> e -> e)) -> Maybe (Arg benv (Exp' e)) -> Arg benv (In (sh, Int) e) -> Arg benv (Out sh e) -> OperationAcc op benv ()
mkDefaultFoldSequential op def argIn argOut = mkGenerate (mkDefaultFoldFunction op def argIn) argOut

shBase = weakenVars kTmp shBase'
shIn = shBase `TupRpair` weakenVars (kTmp .> kSh) n
shTmp = shBase `TupRpair` valueSh kTmp
mkDefaultFoldFunction :: Arg benv (Fun' (e -> e -> e)) -> Maybe (Arg benv (Exp' e)) -> Arg benv (In (sh, Int) e) -> Arg benv (Fun' (sh -> e))
mkDefaultFoldFunction (ArgFun op) def (ArgArray _ (ArrayR (ShapeRsnoc shr) tp) (sh `TupRpair` n) buffers)

Check warning on line 1169 in src/Data/Array/Accelerate/Trafo/Desugar.hs

View workflow job for this annotation

GitHub Actions / stack | ubuntu-latest-x64

Pattern match(es) are non-exhaustive

Check warning on line 1169 in src/Data/Array/Accelerate/Trafo/Desugar.hs

View workflow job for this annotation

GitHub Actions / stack | windows-latest-x64

Pattern match(es) are non-exhaustive

Check warning on line 1169 in src/Data/Array/Accelerate/Trafo/Desugar.hs

View workflow job for this annotation

GitHub Actions / cabal | windows-latest-x64 ghc-9.4 release

Pattern match(es) are non-exhaustive

Check warning on line 1169 in src/Data/Array/Accelerate/Trafo/Desugar.hs

View workflow job for this annotation

GitHub Actions / stack | macOS-latest-x64

Pattern match(es) are non-exhaustive

Check warning on line 1169 in src/Data/Array/Accelerate/Trafo/Desugar.hs

View workflow job for this annotation

GitHub Actions / cabal | macOS-latest-x64 ghc-9.4 release

Pattern match(es) are non-exhaustive
| DeclareVars lhsIdx k1 valueIdx <- declareVars $ shapeType shr

Check warning on line 1170 in src/Data/Array/Accelerate/Trafo/Desugar.hs

View workflow job for this annotation

GitHub Actions / stack | ubuntu-latest-x64

Defined but not used: ‘k1’

Check warning on line 1170 in src/Data/Array/Accelerate/Trafo/Desugar.hs

View workflow job for this annotation

GitHub Actions / stack | windows-latest-x64

Defined but not used: `k1'

Check warning on line 1170 in src/Data/Array/Accelerate/Trafo/Desugar.hs

View workflow job for this annotation

GitHub Actions / cabal | windows-latest-x64 ghc-9.4 release

Defined but not used: ‘k1’

Check warning on line 1170 in src/Data/Array/Accelerate/Trafo/Desugar.hs

View workflow job for this annotation

GitHub Actions / stack | macOS-latest-x64

Defined but not used: ‘k1’

Check warning on line 1170 in src/Data/Array/Accelerate/Trafo/Desugar.hs

View workflow job for this annotation

GitHub Actions / cabal | macOS-latest-x64 ghc-9.4 release

Defined but not used: ‘k1’
, DeclareVars lhsVal k2 valueVal <- declareVars tp =
let
initial = case def of
Nothing ->
Pair
(mkConstant (TupRsingle scalarTypeInt) 1)
(index (ArrayR (ShapeRsnoc shr) tp) (sh `TupRpair` n) buffers (expVars (valueIdx weakenId) `Pair` mkConstant (TupRsingle scalarTypeInt) 0))
Just (ArgExp d) ->
Pair
(mkConstant (TupRsingle scalarTypeInt) 0)
(weakenE weakenEmpty d)

temp = valueTmp weakenId
argG = weaken (kTmp .> kSh) $ ArgFun $ mkDefaultFoldStep2Function f argIn
argTmp = ArgArray Out (ArrayR shr tp) shTmp temp
lhs = LeftHandSidePair (LeftHandSideSingle scalarTypeInt) lhsVal
-- \(idx, accum)
step = Lam lhs $ Body $ Pair
-- (idx + 1
(mkBinary (PrimAdd numType) (Evar $ Var scalarTypeInt $ k2 >:> ZeroIdx) (mkConstant (TupRsingle scalarTypeInt) 1))
-- , op accum (input !! idx)
$ apply2 tp op (expVars $ valueVal weakenId)
$ index (ArrayR (ShapeRsnoc shr) tp) (sh `TupRpair` n) buffers
$ expVars (valueIdx $ weakenSucc k2) `Pair` Evar (Var scalarTypeInt $ k2 >:> ZeroIdx)

condition =
Lam (LeftHandSidePair (LeftHandSideSingle scalarTypeInt) (LeftHandSideWildcard tp))
$ Body $ mkBinary (PrimLt singleType) (Evar $ Var scalarTypeInt ZeroIdx) (paramsIn (TupRsingle scalarType) n)
in
alet lhsSh (Compute $ mkBinary (PrimRem integralType) (paramsIn (TupRsingle scalarTypeInt) n) (Const scalarTypeInt 2))
$ aletUnique lhsTmp (mkDefaultFoldAllocOrOutput (weaken kSh argOut) $ groundToExpVar (shapeType shr) shTmp')
$ alet (LeftHandSideWildcard TupRunit) (mkGenerate argG argTmp)
$ Return (shTmp `TupRpair` temp)

-- Allocates a new intermediate array or returns the output array.
-- If the inner dimension is 1, returns the output array, as we are in the last iteration.
-- Otherwise, allocates a new intermediate array.
--
mkDefaultFoldAllocOrOutput :: Arg benv (Out sh e) -> ExpVars benv (sh, Int) -> OperationAcc op benv (Buffers e)
mkDefaultFoldAllocOrOutput (ArgArray _ (ArrayR shr e) _ output) sh@(TupRpair _ y)
= Alet (LeftHandSideSingle $ GroundRscalar scalarType) (TupRsingle Shared) (Compute $ mkBinary (PrimEq singleType) (paramsIn' y) $ mkConstant (TupRsingle scalarTypeInt) 1)
$ Acond (Var scalarTypeWord8 ZeroIdx)
(Return $ weakenVars (weakenSucc weakenId) output)
(desugarAlloc (ArrayR (ShapeRsnoc shr) e) $ weakenVars (weakenSucc weakenId) sh)

-- \(x, y) -> f (a !! (x, 2*y)) (a !! (x, 2*y+1))
-- ==> \(x, y) -> let z = toIndex (x, 2*y) in f (a ! z) (a ! z+1)
mkDefaultFoldStep2Function :: forall benv sh e. HasCallStack => Fun benv (e -> e -> e) -> Arg benv (In (sh, Int) e) -> Fun benv ((sh, Int) -> e)
mkDefaultFoldStep2Function f (ArgArray _ (ArrayR (ShapeRsnoc shr) tp) sh input)
| DeclareVars lhsX kX valueX <- declareVars $ shapeType shr
, DeclareVars lhsY kY valueY <- declareVars $ TupRsingle scalarTypeInt
-- \(x, y) ->
= Lam (lhsX `LeftHandSidePair` lhsY)
$ Body
-- let z = toIndex (x, 2*y)
$ Let (LeftHandSideSingle scalarTypeInt) (ToIndex (ShapeRsnoc shr) (paramsIn (shapeType $ ShapeRsnoc shr) sh) (expVars (valueX kY) `Pair` mkBinary (PrimMul numType) (expVars $ valueY weakenId) (mkConstant (TupRsingle scalarTypeInt) 2)))
-- f
$ apply2 tp f
-- (a ! z)
(linearIndex tp input $ {- z -} Var scalarTypeInt ZeroIdx)
-- (let w = z + 1 in a ! w)
(Let (LeftHandSideSingle scalarTypeInt) (mkBinary (PrimAdd numType) (Evar $ Var scalarTypeInt ZeroIdx) (mkConstant (TupRsingle scalarTypeInt) 1))
$ linearIndex tp input $ Var scalarTypeInt ZeroIdx)
ArgFun $ Lam lhsIdx $ Body
$ Let lhs (While condition step initial)
$ expVars $ valueVal weakenId

-- In case of a scan with a default value, prepends the initial value before the other elements
-- The default value is placed as the first value in case of a left-to-right scan, or as the
Expand Down
2 changes: 1 addition & 1 deletion src/Data/Array/Accelerate/Trafo/Schedule/Uniform/Future.hs
Original file line number Diff line number Diff line change
Expand Up @@ -707,7 +707,7 @@ loopFuture resolved (FutureBuffer tp ref (Move readLockSignal) (Just (Move write
$ Just $ Borrow (Just signalW) resolverR
| otherwise -> internalError "input or output impossible"
}
loopFuture resolved (FutureBuffer tp ref (Lock readLockSignal readLockResolver) (Just (Lock writeLockSignal writeLockResolver))) = undefined
loopFuture resolved (FutureBuffer tp ref (Lock readLockSignal readLockResolver) (Just (Lock writeLockSignal writeLockResolver))) =
-- A borrowed writable buffer
-- We must add two signals (and accompanying signal resolvers) to the state
-- to synchronize read and write access. Furthermore we need to declare two
Expand Down
Loading

0 comments on commit 3eb258e

Please sign in to comment.