Skip to content

Commit

Permalink
more cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
dpvanbalen committed Jun 27, 2024
1 parent ce1db2b commit 11512d6
Showing 1 changed file with 26 additions and 32 deletions.
58 changes: 26 additions & 32 deletions src/Data/Array/Accelerate/AST/Partitioned.hs
Original file line number Diff line number Diff line change
@@ -1,23 +1,21 @@
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeApplications #-}

{-# LANGUAGE TypeFamilyDependencies#-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_GHC -Wno-orphans #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# OPTIONS_GHC -Wno-name-shadowing #-}

-- |
-- Module : Data.Array.Accelerate.AST.Partitioned
Expand All @@ -31,14 +29,12 @@
module Data.Array.Accelerate.AST.Partitioned (
module Data.Array.Accelerate.AST.Operation,
module Data.Array.Accelerate.AST.Partitioned,
GroundR(..), GroundsR, GroundVar, GroundVars, NFData'(..), Arg(..),
GroundR(..), NFData'(..), Arg(..),
AccessGroundR(..),
PreArgs(..), Args, Modifier(..),
Exp', Var', Fun', In, Out, Mut,
flattenClustered, flattenCluster,
PreArgs(..), Modifier(..),
) where

import Data.Array.Accelerate.AST.Idx

import Data.Array.Accelerate.AST.Operation hiding (OperationAcc, OperationAfun)

import Prelude hiding ( take )
Expand All @@ -50,21 +46,18 @@ import Data.Array.Accelerate.AST.LeftHandSide
import Data.Array.Accelerate.Representation.Shape (ShapeR (..), shapeType, typeShape)
import Data.Type.Equality
import Unsafe.Coerce (unsafeCoerce)
import Data.Array.Accelerate.Representation.Type (TypeR, TupR (..), mapTupR, Distribute)
import Data.Array.Accelerate.Representation.Type (TypeR, TupR (..), mapTupR)
import Data.Array.Accelerate.Type (ScalarType (..), SingleType (..), NumType (..), IntegralType (..))
import Data.Array.Accelerate.AST.Environment (Env (..), prj')
import Data.Functor.Identity

import Data.Array.Accelerate.Trafo.Partitioning.ILP.Labels (Labels, LabelledArgs, LabelledArg (..), ALabel (..), ALabels (..), ELabel (..), Label)
import Data.List (nub, sortOn)
import Lens.Micro (_1)

import Data.Array.Accelerate.Trafo.Partitioning.ILP.Labels (Labels, LabelledArgs, LabelledArg (..), ALabel (..), ELabel (..), Label)
import Data.List (sortOn)
import qualified Data.Functor.Const as C
import Data.Array.Accelerate.Trafo.Partitioning.ILP.Graph (LabelledArgOp (..), BackendClusterArg, MakesILP (..), LabelledArgsOp, unOpLabels, BackendCluster)
import Data.Array.Accelerate.Trafo.Partitioning.ILP.Graph (LabelledArgOp (..), BackendClusterArg, MakesILP (..), LabelledArgsOp, BackendCluster)
import Data.Array.Accelerate.Trafo.Operation.LiveVars
import Data.Maybe (fromJust, Maybe (Nothing))
import Data.Maybe (fromJust)
import Data.Array.Accelerate.AST.Var (varsType)
import qualified Debug.Trace
import Data.Array.Accelerate.Error (HasCallStack)
import Data.Array.Accelerate.Analysis.Match (matchShapeR)


Expand Down Expand Up @@ -96,6 +89,7 @@ data SingleOp op args where
pattern Op :: SLVOp op args -> Label -> Cluster op args
pattern Op slv l <- SingleOp (toOld -> slv) l where
Op (SLV (SOp (SOAOp op soas) sortedargs) subargs) l = SingleOp (Single op soas sortedargs subargs) l
toOld :: SingleOp op args -> SLVOp op args
toOld (Single op soas sortedargs subargs) = SLV (SOp (SOAOp op soas) sortedargs) subargs

data SLVOp op args where
Expand Down Expand Up @@ -154,7 +148,7 @@ deriving instance Show (Fusion l r total)


soaShrink :: forall args expanded f
.
.
-- (forall a. Show (f a))
-- =>
(forall l r g. f (g l) -> f (g r) -> f (g (l,r)))
Expand Down Expand Up @@ -189,13 +183,13 @@ split (ArgArray Out (ArrayR shr (TupRpair rl rr)) sh (TupRpair bufl bufr)) = (Ar
split _ = error "non-array soa"

splitLabelledArgs :: LabelledArg env (f (l,r)) -> (LabelledArg env (f l), LabelledArg env (f r))
splitLabelledArgs (L _ ((Arr (TupRsingle _), _))) = error "pair in single"
splitLabelledArgs (L arg ((Arr (TupRpair labl labr), labs))) = bimap (`L` ((Arr labl, labs))) (`L` ((Arr labr, labs))) $ split arg
splitLabelledArgs (L _ ((NotArr, _))) = error "SOA'd non-array arg"
splitLabelledArgs (L _ (Arr (TupRsingle _), _)) = error "pair in single"
splitLabelledArgs (L arg (Arr (TupRpair labl labr), labs)) = bimap (`L` (Arr labl, labs)) (`L` (Arr labr, labs)) $ split arg
splitLabelledArgs (L _ (NotArr, _)) = error "SOA'd non-array arg"
splitLabelledArgsOp :: LabelledArgOp op env (f (l,r)) -> (LabelledArgOp op env (f l), LabelledArgOp op env (f r))
splitLabelledArgsOp (LOp _ ((Arr (TupRsingle _), _)) b) = error "pair in single"
splitLabelledArgsOp (LOp arg ((Arr (TupRpair labl labr), labs)) b) = bimap (flip (`LOp` ((Arr labl, labs))) b) (flip (`LOp` ((Arr labr, labs))) b) $ split arg
splitLabelledArgsOp (LOp _ ((NotArr, _)) _) = error "SOA'd non-array arg"
splitLabelledArgsOp (LOp _ (Arr (TupRsingle _), _) b) = error "pair in single"

Check warning on line 190 in src/Data/Array/Accelerate/AST/Partitioned.hs

View workflow job for this annotation

GitHub Actions / stack | ubuntu-latest-x64

Defined but not used: ‘b’

Check warning on line 190 in src/Data/Array/Accelerate/AST/Partitioned.hs

View workflow job for this annotation

GitHub Actions / cabal | ubuntu-latest-x64 ghc-9.4 release

Defined but not used: ‘b’

Check warning on line 190 in src/Data/Array/Accelerate/AST/Partitioned.hs

View workflow job for this annotation

GitHub Actions / stack | windows-latest-x64

Defined but not used: `b'

Check warning on line 190 in src/Data/Array/Accelerate/AST/Partitioned.hs

View workflow job for this annotation

GitHub Actions / cabal | windows-latest-x64 ghc-9.4 release

Defined but not used: ‘b’

Check warning on line 190 in src/Data/Array/Accelerate/AST/Partitioned.hs

View workflow job for this annotation

GitHub Actions / cabal | macOS-latest-x64 ghc-9.4 release

Defined but not used: ‘b’

Check warning on line 190 in src/Data/Array/Accelerate/AST/Partitioned.hs

View workflow job for this annotation

GitHub Actions / stack | macOS-latest-x64

Defined but not used: ‘b’
splitLabelledArgsOp (LOp arg (Arr (TupRpair labl labr), labs) b) = bimap (flip (`LOp` (Arr labl, labs)) b) (flip (`LOp` (Arr labr, labs)) b) $ split arg
splitLabelledArgsOp (LOp _ (NotArr, _) _) = error "SOA'd non-array arg"

soaOut :: forall env args expanded f. (forall sh l r. f (Value sh (l,r)) -> (f (Value sh l),f (Value sh r))) -> Args env args -> SOAs args expanded -> Env f (OutArgs args) -> Env f (OutArgs expanded)
soaOut _ ArgsNil SOArgsNil Empty = Empty
Expand Down Expand Up @@ -445,7 +439,7 @@ fuse' :: MakesILP op => LabelledArgsOp op env l -> LabelledArgsOp op env r -> Pr
-> (forall sh e. f (Out sh e) -> f (In sh e) -> f (Var' sh))
-> (forall args. PreArgs f args -> Cluster op args -> result)
-> result
fuse' labl labr largs rargs l r c k =
fuse' labl labr largs rargs l r c k =
mkFused labl labr $ \f -> k (both c f largs rargs) (Fused f l r)

mkFused :: MakesILP op => LabelledArgsOp op env l -> LabelledArgsOp op env r -> (forall args. Fusion l r args -> result) -> result
Expand Down Expand Up @@ -518,7 +512,7 @@ sortArgs args k =
keepAll (_:>:as) = SubArgKeep `SubArgsLive` keepAll as
-- If it's a buffer, we only care about its unique label. If it's not a buffer, the other labels suffice to give any ordering.
getElabelForSort :: LabelledArg env a -> Either ELabel Labels
getElabelForSort (L (ArgArray m (ArrayR _ TupRsingle{}) _ _) (Arr (TupRsingle (C.Const e)),_))
getElabelForSort (L (ArgArray m (ArrayR _ TupRsingle{}) _ _) (Arr (TupRsingle (C.Const e)),_))
| In <- m = Left e
| Out <- m = Left e
getElabelForSort (L _ (_,ls)) = Right ls
Expand Down Expand Up @@ -554,11 +548,11 @@ mkSOA (ArgArray Mut _ _ _) k = k SOArgSingle
mkSOA _ _ = error "pair or unit in a tuprsingle somewhere"

instance ShrinkArg (BackendClusterArg op) => SLVOperation (Clustered op) where
slvOperation (Clustered cluster b) = Just $ ShrinkOperation $ \ff' a' a ->
slvOperation (Clustered cluster b) = Just $ ShrinkOperation $ \ff' a' a ->
case slvCluster cluster ff' a' a of
ShrunkOperation' cluster' args ->
ShrunkOperation (Clustered cluster' $ shrinkArgs ff' b) args

instance SimplifyOperation (Clustered op)
-- Default implementation, where detectCopy always returns []

Expand All @@ -579,7 +573,7 @@ slvCluster (Op op label) sub args' args

slvCluster (Fused fusion left right) sub args1' args1 = splitslvstuff fusion sub args1' args1 $
\f' lsub largs' largs rsub rargs' rargs -> case (slvCluster left lsub largs' largs, slvCluster right rsub rargs' rargs) of
(ShrunkOperation' lop largs'', ShrunkOperation' rop rargs'') ->
(ShrunkOperation' lop largs'', ShrunkOperation' rop rargs'') ->
ShrunkOperation' (Fused f' lop rop) (both (\x _ -> outvar x) f' largs'' rargs'')
where
splitslvstuff :: Fusion l r a
Expand Down Expand Up @@ -666,4 +660,4 @@ showSorted :: LabelledArgsOp op env args -> String
showSorted ArgsNil = ""
showSorted (a :>: args) = case a of
LOp (ArgArray m _ _ _) (_,ls) _ -> show m <> "{" <> show ls <> "}" <> showSorted args
_ -> showSorted args
_ -> showSorted args

0 comments on commit 11512d6

Please sign in to comment.