From f82ed093ec7f91c85f0b3163e262df1576e622fc Mon Sep 17 00:00:00 2001 From: David van Balen Date: Mon, 15 Apr 2024 14:56:18 +0200 Subject: [PATCH] slv on partitionedacc --- accelerate.cabal | 34 +-- src/Data/Array/Accelerate/AST/Partitioned.hs | 280 ++++++++++-------- src/Data/Array/Accelerate/Eval.hs | 62 ++-- .../Array/Accelerate/Pretty/Partitioned.hs | 9 +- .../Accelerate/Trafo/Operation/LiveVars.hs | 1 + .../Trafo/Partitioning/ILP/Clustering.hs | 8 +- .../Trafo/Partitioning/ILP/Graph.hs | 4 +- 7 files changed, 231 insertions(+), 167 deletions(-) diff --git a/accelerate.cabal b/accelerate.cabal index 59f49d7d6..7e8d581eb 100644 --- a/accelerate.cabal +++ b/accelerate.cabal @@ -144,15 +144,15 @@ extra-source-files: cbits/xkcp/*.inc -- TRACY -- These are referenced directly using the FFI - cbits/tracy/public/*.cpp - cbits/tracy/public/tracy/*.h - cbits/tracy/public/tracy/*.hpp - cbits/tracy/public/common/*.h - cbits/tracy/public/common/*.hpp - cbits/tracy/public/common/*.cpp - cbits/tracy/public/client/*.h - cbits/tracy/public/client/*.hpp - cbits/tracy/public/client/*.cpp + -- cbits/tracy/public/*.cpp + -- cbits/tracy/public/tracy/*.h + -- cbits/tracy/public/tracy/*.hpp + -- cbits/tracy/public/common/*.h + -- cbits/tracy/public/common/*.hpp + -- cbits/tracy/public/common/*.cpp + -- cbits/tracy/public/client/*.h + -- cbits/tracy/public/client/*.hpp + -- cbits/tracy/public/client/*.cpp -- These are used to build Tracy's client tools in Setup.hs cbits/tracy/capture/build/unix/Makefile cbits/tracy/capture/build/unix/*.mk @@ -176,12 +176,12 @@ extra-source-files: cbits/tracy/profiler/src/*.cpp cbits/tracy/profiler/src/*.h cbits/tracy/profiler/src/*.hpp - cbits/tracy/profiler/src/font/*.hpp - cbits/tracy/profiler/src/imgui/*.cpp - cbits/tracy/profiler/src/imgui/*.h - cbits/tracy/public/libbacktrace/*.cpp - cbits/tracy/public/libbacktrace/*.h - cbits/tracy/public/libbacktrace/*.hpp + -- cbits/tracy/profiler/src/font/*.hpp + -- cbits/tracy/profiler/src/imgui/*.cpp + -- cbits/tracy/profiler/src/imgui/*.h + -- cbits/tracy/public/libbacktrace/*.cpp + -- cbits/tracy/public/libbacktrace/*.h + -- cbits/tracy/public/libbacktrace/*.hpp cbits/tracy/server/*.cpp cbits/tracy/server/*.h cbits/tracy/server/*.hpp @@ -190,7 +190,7 @@ extra-source-files: cbits/tracy/zstd/common/*.h cbits/tracy/zstd/compress/*.c cbits/tracy/zstd/compress/*.h - cbits/tracy/zstd/decompress/*.S + -- cbits/tracy/zstd/decompress/*.S cbits/tracy/zstd/decompress/*.c cbits/tracy/zstd/decompress/*.h cbits/tracy/zstd/dictBuilder/*.c @@ -451,6 +451,7 @@ library Data.Array.Accelerate.Trafo.LiveVars Data.Array.Accelerate.Trafo.NewNewFusion Data.Array.Accelerate.Trafo.Operation.Substitution + Data.Array.Accelerate.Trafo.Operation.LiveVars Data.Array.Accelerate.Trafo.Partitioning.ILP Data.Array.Accelerate.Trafo.Partitioning.ILP.Clustering Data.Array.Accelerate.Trafo.Partitioning.ILP.Graph @@ -528,7 +529,6 @@ library Data.Array.Accelerate.Trafo.Exp.Algebra Data.Array.Accelerate.Trafo.Environment Data.Array.Accelerate.Trafo.Operation.Simplify - Data.Array.Accelerate.Trafo.Operation.LiveVars Data.Array.Accelerate.Trafo.Shrink Data.Atomic diff --git a/src/Data/Array/Accelerate/AST/Partitioned.hs b/src/Data/Array/Accelerate/AST/Partitioned.hs index 6bc7bac2b..34df7dd70 100644 --- a/src/Data/Array/Accelerate/AST/Partitioned.hs +++ b/src/Data/Array/Accelerate/AST/Partitioned.hs @@ -45,7 +45,7 @@ import Data.Bifunctor import Data.Array.Accelerate.Trafo.Desugar (ArrayDescriptor(..)) import Data.Array.Accelerate.Representation.Array (Array, Buffers, ArrayR (..)) import Data.Array.Accelerate.AST.LeftHandSide -import Data.Array.Accelerate.Representation.Shape (ShapeR (..), shapeType) +import Data.Array.Accelerate.Representation.Shape (ShapeR (..), shapeType, typeShape) import Data.Type.Equality import Unsafe.Coerce (unsafeCoerce) import Data.Array.Accelerate.Representation.Type (TypeR, TupR (..), mapTupR, Distribute) @@ -64,33 +64,25 @@ import Data.Array.Accelerate.AST.Var (varsType) import qualified Debug.Trace -slv :: (forall sh e. f (Out sh e) -> f (Var' sh)) -> SubArgs args args' -> PreArgs f args -> PreArgs f args' -slv _ SubArgsNil ArgsNil = ArgsNil -slv f (SubArgsDead sas) (arg:>:args) = f arg :>: slv f sas args -slv f (SubArgsLive SubArgKeep sas) (arg:>:args) = arg :>: slv f sas args -slv _ _ _ = error "not soa'ed" -slv' :: (forall sh e. f (Var' sh) -> f (Out sh e)) -> SubArgs args args' -> PreArgs f args' -> PreArgs f args -slv' _ SubArgsNil ArgsNil = ArgsNil -slv' f (SubArgsDead sas) (arg:>:args) = f arg :>: slv' f sas args -slv' f (SubArgsLive SubArgKeep sas) (arg:>:args) = arg :>: slv' f sas args -slv' _ _ _ = error "not soa'ed" -slvIn :: (forall sh e. f (Var' sh) -> f (Sh sh e)) -> SubArgs args args' -> Env f (InArgs args') -> Env f (InArgs args) -slvIn _ SubArgsNil Empty = Empty -slvIn f (SubArgsDead sas) (Push env x) = Push (slvIn f sas env) (f x) -slvIn f (SubArgsLive SubArgKeep sas) (Push env x) = Push (slvIn f sas env) x -slvIn _ _ _ = error "not soa'ed" -slvOut :: Args env args' -> SubArgs args args' -> Env f (OutArgs args) -> Env f (OutArgs args') -slvOut _ SubArgsNil Empty = Empty -slvOut (_:>:args) (SubArgsDead sas) (Push env _) = slvOut args sas env -slvOut (a :>: args) (SubArgsLive SubArgKeep sas) env = case a of - ArgArray Out _ _ _ - | Push env' x <- env -> Push (slvOut args sas env') x - ArgArray In _ _ _ -> slvOut args sas env - ArgArray Mut _ _ _ -> slvOut args sas env - ArgVar _ -> slvOut args sas env - ArgFun _ -> slvOut args sas env - ArgExp _ -> slvOut args sas env -slvOut _ _ _ = error "not soa'ed" + + +type PartitionedAcc op = PreOpenAcc (Clustered op) +type PartitionedAfun op = PreOpenAfun (Clustered op) + + +data Clustered op args = Clustered (Cluster op args) (BackendCluster op args) + +data Cluster op args where + Op :: SLVOp op args -> Label -> Cluster op args + Fused :: Fusion largs rargs args + -> Cluster op largs + -> Cluster op rargs + -> Cluster op args + +data SLVOp op args where + SLV :: SortedOp op big + -> SubArgs big small + -> SLVOp op small -- a wrapper around operations which sorts the (term and type level) arguments on global labels, to line the arguments up for Fusion data SortedOp op args where @@ -117,6 +109,32 @@ data SOA arg appendto result where SOArgSingle :: SOA arg args (arg -> args) SOArgTup :: SOA (f right) args args' -> SOA (f left) args' args'' -> SOA (f (left,right)) args args'' + +-- These correspond to the inference rules in the paper +-- datatype describing how to combine the arguments of two fused clusters +data Fusion largs rars args where + EmptyF :: Fusion () () () + Vertical :: ArrayR (Array sh e) + -> Fusion l r a + -> Fusion (Out sh e -> l) (In sh e -> r) (Var' sh -> a) + Horizontal :: Fusion l r a + -> Fusion (In sh e -> l) (In sh e -> r) (In sh e -> a) + Diagonal :: Fusion l r a + -> Fusion (Out sh e -> l) (In sh e -> r) (Out sh e -> a) + IntroI1 :: Fusion l r a + -> Fusion (In sh e -> l) r (In sh e -> a) + IntroI2 :: Fusion l r a + -> Fusion l (In sh e -> r) (In sh e -> a) + IntroO1 :: Fusion l r a + -> Fusion (Out sh e -> l) r (Out sh e -> a) + IntroO2 :: Fusion l r a + -> Fusion l (Out sh e -> r) (Out sh e -> a) + -- not in the paper; not meant for array args: + IntroL :: Fusion l r a -> Fusion (x -> l) r (x -> a) + IntroR :: Fusion l r a -> Fusion l (x -> r) (x -> a) +deriving instance Show (Fusion l r total) + + soaShrink :: forall args expanded f . (forall a. Show (f a)) => (forall l r g. f (g l) -> f (g r) -> f (g (l,r))) @@ -202,37 +220,7 @@ justOut (ArgFun _ :>: args) (_ :>: fs) = justOut args fs justOut (ArgArray In _ _ _ :>: args) (_ :>: fs) = justOut args fs justOut (ArgArray Mut _ _ _ :>: args) (_ :>: fs) = justOut args fs -data Cluster op args where - Op :: SortedOp op args -> Label -> Cluster op args - Fused :: Fusion largs rargs args - -> Cluster op largs - -> Cluster op rargs - -> Cluster op args - --- These correspond to the inference rules in the paper --- datatype describing how to combine the arguments of two fused clusters -data Fusion largs rars args where - EmptyF :: Fusion () () () - Vertical :: ArrayR (Array sh e) - -> Fusion l r a - -> Fusion (Out sh e -> l) (In sh e -> r) (Var' sh -> a) - Horizontal :: Fusion l r a - -> Fusion (In sh e -> l) (In sh e -> r) (In sh e -> a) - Diagonal :: Fusion l r a - -> Fusion (Out sh e -> l) (In sh e -> r) (Out sh e -> a) - IntroI1 :: Fusion l r a - -> Fusion (In sh e -> l) r (In sh e -> a) - IntroI2 :: Fusion l r a - -> Fusion l (In sh e -> r) (In sh e -> a) - IntroO1 :: Fusion l r a - -> Fusion (Out sh e -> l) r (Out sh e -> a) - IntroO2 :: Fusion l r a - -> Fusion l (Out sh e -> r) (Out sh e -> a) - -- not in the paper; not meant for array args: - IntroL :: Fusion l r a -> Fusion (x -> l) r (x -> a) - IntroR :: Fusion l r a -> Fusion l (x -> r) (x -> a) -deriving instance Show (Fusion l r total) left :: Fusion largs rargs args -> Args env args -> Args env largs left = left' (\(ArgVar sh) -> ArgArray Out (ArrayR (varsToShapeR sh) er) (mapTupR (\(Var t ix)->Var (GroundRscalar t) ix) sh) er) @@ -285,13 +273,6 @@ both k (IntroL f) (l:>:ls) rs = l :>: both k f ls rs both k (IntroR f) ls (r:>:rs) = r :>: both k f ls rs -type PartitionedAcc op = PreOpenAcc (Clustered op) -type PartitionedAfun op = PreOpenAfun (Clustered op) - - - -data Clustered op args = Clustered (Cluster op args) (BackendCluster op args) - varsToShapeR :: Vars ScalarType g sh -> ShapeR sh varsToShapeR = typeRtoshapeR . varsType @@ -467,21 +448,25 @@ addboth _ _ _ _ = error "fusing non-arrays" singleton :: MakesILP op => Label -> LabelledArgsOp op env args -> op args -> (forall args'. Clustered op args' -> r) -> r singleton l largs op k = mkSOAs (mapArgs (\(LOp a _ _) -> a) largs) $ \soas -> - sortArgs (soaExpand splitLabelledArgs soas (unOpLabels' largs)) $ \sa@(SA sort _) -> - k $ Clustered (Op (SOp (SOAOp op soas) sa) l) (mapArgs getClusterArg $ sort $ soaExpand splitLabelledArgsOp soas largs) + sortArgs (soaExpand splitLabelledArgs soas (unOpLabels' largs)) $ \sa@(SA sort _) slv -> + k $ Clustered (Op (SLV (SOp (SOAOp op soas) sa) slv) l) (mapArgs getClusterArg $ sort $ soaExpand splitLabelledArgsOp soas largs) -- (subargsId $ sort $ soaExpand splitLabelledArgsOp soas largs) -sortArgs :: LabelledArgs env args -> (forall sorted. SortedArgs args sorted -> r) -> r +sortArgs :: LabelledArgs env args -> (forall sorted. SortedArgs args sorted -> SubArgs sorted sorted -> r) -> r sortArgs args k = -- if nub ls /= ls then error "not sure if this works" -- this means that some arguments have identical sets of labels. This should only be a problem if two array arguments share labelsets. -- else - k $ SA - (\args -> case argsFromList . map snd . sortOn fst . zip ls . argsToList $ args of Exists a -> unsafeCoerce a) - (\srts -> case argsFromList . map snd . sortOn fst . zip ls' . argsToList $ srts of Exists a -> unsafeCoerce a) + k (SA + (\args -> case argsFromList . map snd . sortOn fst . zip ls . argsToList $ args of Exists a -> unsafeCoerce a) + (\srts -> case argsFromList . map snd . sortOn fst . zip ls' . argsToList $ srts of Exists a -> unsafeCoerce a)) + (unsafeCoerce $ keepAll args) -- same length, so coerce is safe where args' = argsToList args ls = map (\(Exists (L _ (_,l)))->l) args' ls' = map snd $ sortOn fst $ zip ls [1..] + keepAll :: LabelledArgs env args -> SubArgs args args + keepAll ArgsNil = SubArgsNil + keepAll (_:>:as) = SubArgKeep `SubArgsLive` keepAll as subargsId :: PreArgs f args -> SubArgs args args subargsId ArgsNil = SubArgsNil @@ -510,63 +495,110 @@ mkSOA (ArgFun _) k = k SOArgSingle mkSOA (ArgArray Mut _ _ _) k = k SOArgSingle mkSOA _ _ = error "pair or unit in a tuprsingle somewhere" -instance SLVOperation (Clustered op) where - slvOperation = const Nothing - -outvar :: Arg env (Out sh e) -> Arg env (Var' sh) -outvar (ArgArray Out (ArrayR shr _) sh _) = ArgVar $ groundToExpVar (shapeType shr) sh +instance ShrinkArg (BackendClusterArg op) => SLVOperation (Clustered op) where + slvOperation (Clustered cluster b) = Just $ ShrinkOperation $ \ff' a' a -> + case slvCluster cluster ff' a' a of + ShrunkOperation' cluster' args -> + ShrunkOperation (Clustered cluster' $ shrinkArgs ff' b) args + +-- instance SLVOperation (Cluster op) where +-- slvOperation cluster = -- Nothing +-- Just $ ShrinkOperation $ \sub args' args -> +-- case slvCluster cluster sub args' args of +-- ShrunkOperation' cluster' args'' -> ShrunkOperation cluster' args'' + +slvCluster :: Cluster op f -> SubArgs f f' -> Args env' f' -> Args env f -> ShrunkOperation' (Cluster op) env' f' +slvCluster (Op op label) sub args' args + -- | ShrunkOperation' op' subargs <- undefined op sub args' + | op' <- slvSLVOp op sub + = ShrunkOperation' (Op op' label) args' + where + slvSLVOp :: SLVOp op big -> SubArgs big small -> SLVOp op small + slvSLVOp (SLV op sa1) sa2 = SLV op (composeSubArgs sa1 sa2) -instance SLVOperation (Cluster op) where - slvOperation cluster = Nothing --- Just $ ShrinkOperation $ \sub args' args -> --- case slvCluster cluster sub args' args of --- ShrunkOperation' cluster' args'' -> ShrunkOperation cluster' args'' - --- slvCluster :: Cluster op f -> SubArgs f f' -> Args env' f' -> Args env f -> ShrunkOperation' (Cluster op) env' f' --- slvCluster (Op op label) sub args' _ --- | ShrunkOperation' op' subargs <- slvSLVedOp op sub args' --- = ShrunkOperation' (Op op' label) subargs --- slvCluster (Fused fusion left right) sub args1' args1 = splitslvstuff fusion sub args1' args1 $ --- \f' lsub largs' largs rsub rargs' rargs -> case (slvCluster left lsub largs' largs, slvCluster right rsub rargs' rargs) of --- (ShrunkOperation' lop largs'', ShrunkOperation' rop rargs'') -> --- ShrunkOperation' (Fused f' lop rop) (both (\x _ -> outvar x) f' largs'' rargs'') --- where --- splitslvstuff :: Fusion l r a --- -> SubArgs a a' --- -> Args env' a' --- -> Args env a --- -> (forall l' r'. Fusion l' r' a' -> SubArgs l l' -> Args env' l' -> Args env l -> SubArgs r r' -> Args env' r' -> Args env r -> result) --- -> result --- splitslvstuff EmptyF SubArgsNil ArgsNil ArgsNil k = k EmptyF SubArgsNil ArgsNil ArgsNil SubArgsNil ArgsNil ArgsNil --- splitslvstuff f (SubArgsLive (SubArgOut SubTupRskip) subs) args' args k = error "completely removed out arg using subtupr" --splitslvstuff f (SubArgsDead subs) args' args k --- splitslvstuff f (SubArgsLive (SubArgOut SubTupRkeep) subs) args' args k = splitslvstuff f (SubArgsLive SubArgKeep subs) args' args k --- splitslvstuff f (SubArgsLive (SubArgOut SubTupRpair{}) subs) (arg':>:args') (arg:>:args) k = error "not SOA'd array" --- splitslvstuff (Diagonal f) (SubArgsDead subs) args' (arg@(ArgArray _ r sh _):>:args) k = splitslvstuff (Vertical r f) (SubArgsLive SubArgKeep subs) args' (ArgVar (groundToExpVar (shapeType $ arrayRshape r) sh) :>:args) k --- splitslvstuff (IntroO1 f) (SubArgsDead subs) (arg':>:args') (arg:>:args) k = splitslvstuff f subs args' args $ \f lsubs largs' largs rsubs rargs' rargs -> k (IntroL f) (SubArgsDead lsubs) (arg':>:largs') (arg:>:largs) rsubs rargs' rargs --- splitslvstuff (IntroO2 f) (SubArgsDead subs) (arg':>:args') (arg:>:args) k = splitslvstuff f subs args' args $ \f lsubs largs' largs rsubs rargs' rargs -> k (IntroR f) lsubs largs' largs (SubArgsDead rsubs) (arg':>:rargs') (arg:>:rargs) --- splitslvstuff (IntroL f) (SubArgsDead subs) (arg':>:args') (arg:>:args) k = error "out in IntroL/R" --- splitslvstuff (IntroR f) (SubArgsDead subs) (arg':>:args') (arg:>:args) k = error "out in IntroL/R" --- splitslvstuff (Vertical r f) (SubArgsLive SubArgKeep subs) (ArgVar arg':>:args') (ArgVar arg:>:args) k = splitslvstuff f subs args' args $ \f lsubs largs' largs rsubs rargs' rargs -> k (Vertical r f) (SubArgsLive SubArgKeep lsubs) (ArgArray Out r sh' buf :>:largs') (ArgArray Out r sh buf :>:largs) (SubArgsLive SubArgKeep rsubs) (ArgArray In r sh' buf :>:rargs') (ArgArray In r sh buf :>:rargs) --- where --- buf = error "fused away buffer" --- sh = expToGroundVar arg --- sh' = expToGroundVar arg' --- splitslvstuff (Diagonal f) (SubArgsLive SubArgKeep subs) (arg'@(ArgArray Out r' sh' buf'):>:args') (arg@(ArgArray Out r sh buf):>:args) k = splitslvstuff f subs args' args $ \f lsubs largs' largs rsubs rargs' rargs -> k (Diagonal f) (SubArgsLive SubArgKeep lsubs) (arg':>:largs') (arg:>:largs) (SubArgsLive SubArgKeep rsubs) (ArgArray In r' sh' buf':>:rargs') (ArgArray In r sh buf:>:rargs) --- splitslvstuff (Horizontal f) (SubArgsLive SubArgKeep subs) (arg':>:args') (arg:>:args) k = splitslvstuff f subs args' args $ \f lsubs largs' largs rsubs rargs' rargs -> k (Horizontal f) (SubArgsLive SubArgKeep lsubs) ( arg':>:largs') ( arg:>:largs) (SubArgsLive SubArgKeep rsubs) ( arg':>:rargs') ( arg:>:rargs) --- splitslvstuff (IntroI1 f) (SubArgsLive SubArgKeep subs) (arg':>:args') (arg:>:args) k = splitslvstuff f subs args' args $ \f lsubs largs' largs rsubs rargs' rargs -> k (IntroI1 f) (SubArgsLive SubArgKeep lsubs) ( arg':>:largs') ( arg:>:largs) rsubs rargs' rargs --- splitslvstuff (IntroI2 f) (SubArgsLive SubArgKeep subs) (arg':>:args') (arg:>:args) k = splitslvstuff f subs args' args $ \f lsubs largs' largs rsubs rargs' rargs -> k (IntroI2 f) lsubs largs' largs (SubArgsLive SubArgKeep rsubs) ( arg':>:rargs') ( arg:>:rargs) --- splitslvstuff (IntroO1 f) (SubArgsLive SubArgKeep subs) (arg':>:args') (arg:>:args) k = splitslvstuff f subs args' args $ \f lsubs largs' largs rsubs rargs' rargs -> k (IntroO1 f) (SubArgsLive SubArgKeep lsubs) ( arg':>:largs') ( arg:>:largs) rsubs rargs' rargs --- splitslvstuff (IntroO2 f) (SubArgsLive SubArgKeep subs) (arg':>:args') (arg:>:args) k = splitslvstuff f subs args' args $ \f lsubs largs' largs rsubs rargs' rargs -> k (IntroO2 f) lsubs largs' largs (SubArgsLive SubArgKeep rsubs) ( arg':>:rargs') ( arg:>:rargs) --- splitslvstuff (IntroL f) (SubArgsLive SubArgKeep subs) (arg':>:args') (arg:>:args) k = splitslvstuff f subs args' args $ \f lsubs largs' largs rsubs rargs' rargs -> k (IntroL f) (SubArgsLive SubArgKeep lsubs) ( arg':>:largs') ( arg:>:largs) rsubs rargs' rargs --- splitslvstuff (IntroR f) (SubArgsLive SubArgKeep subs) (arg':>:args') (arg:>:args) k = splitslvstuff f subs args' args $ \f lsubs largs' largs rsubs rargs' rargs -> k (IntroR f) lsubs largs' largs (SubArgsLive SubArgKeep rsubs) ( arg':>:rargs') ( arg:>:rargs) - - --- -- Variant of ShrunkOperation where f is not an existential --- data ShrunkOperation' op env f where --- ShrunkOperation' :: op f -> Args env f -> ShrunkOperation' op env f +slvCluster (Fused fusion left right) sub args1' args1 = splitslvstuff fusion sub args1' args1 $ + \f' lsub largs' largs rsub rargs' rargs -> case (slvCluster left lsub largs' largs, slvCluster right rsub rargs' rargs) of + (ShrunkOperation' lop largs'', ShrunkOperation' rop rargs'') -> + ShrunkOperation' (Fused f' lop rop) (both (\x _ -> outvar x) f' largs'' rargs'') + where + splitslvstuff :: Fusion l r a + -> SubArgs a a' + -> Args env' a' + -> Args env a + -> (forall l' r'. Fusion l' r' a' -> SubArgs l l' -> Args env' l' -> Args env l -> SubArgs r r' -> Args env' r' -> Args env r -> result) + -> result + splitslvstuff EmptyF SubArgsNil ArgsNil ArgsNil k = k EmptyF SubArgsNil ArgsNil ArgsNil SubArgsNil ArgsNil ArgsNil + splitslvstuff f (SubArgsLive (SubArgOut SubTupRskip) subs) args' args k = error "completely removed out arg using subtupr" --splitslvstuff f (SubArgsDead subs) args' args k + splitslvstuff f (SubArgsLive (SubArgOut SubTupRkeep) subs) args' args k = splitslvstuff f (SubArgsLive SubArgKeep subs) args' args k + splitslvstuff f (SubArgsLive (SubArgOut SubTupRpair{}) subs) (arg':>:args') (arg:>:args) k = error "not SOA'd array" + splitslvstuff (Diagonal f) (SubArgsDead subs) args' (arg@(ArgArray _ r sh _):>:args) k = splitslvstuff (Vertical r f) (SubArgsLive SubArgKeep subs) args' (ArgVar (groundToExpVar (shapeType $ arrayRshape r) sh) :>:args) k + splitslvstuff (IntroO1 f) (SubArgsDead subs) (arg':>:args') (arg:>:args) k = splitslvstuff f subs args' args $ \f lsubs largs' largs rsubs rargs' rargs -> k (IntroL f) (SubArgsDead lsubs) (arg':>:largs') (arg:>:largs) rsubs rargs' rargs + splitslvstuff (IntroO2 f) (SubArgsDead subs) (arg':>:args') (arg:>:args) k = splitslvstuff f subs args' args $ \f lsubs largs' largs rsubs rargs' rargs -> k (IntroR f) lsubs largs' largs (SubArgsDead rsubs) (arg':>:rargs') (arg:>:rargs) + splitslvstuff (IntroL f) (SubArgsDead subs) (arg':>:args') (arg:>:args) k = error "out in IntroL/R" + splitslvstuff (IntroR f) (SubArgsDead subs) (arg':>:args') (arg:>:args) k = error "out in IntroL/R" + splitslvstuff (Vertical r f) (SubArgsLive SubArgKeep subs) (ArgVar arg':>:args') (ArgVar arg:>:args) k = splitslvstuff f subs args' args $ \f lsubs largs' largs rsubs rargs' rargs -> k (Vertical r f) (SubArgsLive SubArgKeep lsubs) (ArgArray Out r sh' buf :>:largs') (ArgArray Out r sh buf :>:largs) (SubArgsLive SubArgKeep rsubs) (ArgArray In r sh' buf :>:rargs') (ArgArray In r sh buf :>:rargs) + where + buf = error "fused away buffer" + sh = expToGroundVar arg + sh' = expToGroundVar arg' + splitslvstuff (Diagonal f) (SubArgsLive SubArgKeep subs) (arg'@(ArgArray Out r' sh' buf'):>:args') (arg@(ArgArray Out r sh buf):>:args) k = splitslvstuff f subs args' args $ \f lsubs largs' largs rsubs rargs' rargs -> k (Diagonal f) (SubArgsLive SubArgKeep lsubs) (arg':>:largs') (arg:>:largs) (SubArgsLive SubArgKeep rsubs) (ArgArray In r' sh' buf':>:rargs') (ArgArray In r sh buf:>:rargs) + splitslvstuff (Horizontal f) (SubArgsLive SubArgKeep subs) (arg':>:args') (arg:>:args) k = splitslvstuff f subs args' args $ \f lsubs largs' largs rsubs rargs' rargs -> k (Horizontal f) (SubArgsLive SubArgKeep lsubs) ( arg':>:largs') ( arg:>:largs) (SubArgsLive SubArgKeep rsubs) ( arg':>:rargs') ( arg:>:rargs) + splitslvstuff (IntroI1 f) (SubArgsLive SubArgKeep subs) (arg':>:args') (arg:>:args) k = splitslvstuff f subs args' args $ \f lsubs largs' largs rsubs rargs' rargs -> k (IntroI1 f) (SubArgsLive SubArgKeep lsubs) ( arg':>:largs') ( arg:>:largs) rsubs rargs' rargs + splitslvstuff (IntroI2 f) (SubArgsLive SubArgKeep subs) (arg':>:args') (arg:>:args) k = splitslvstuff f subs args' args $ \f lsubs largs' largs rsubs rargs' rargs -> k (IntroI2 f) lsubs largs' largs (SubArgsLive SubArgKeep rsubs) ( arg':>:rargs') ( arg:>:rargs) + splitslvstuff (IntroO1 f) (SubArgsLive SubArgKeep subs) (arg':>:args') (arg:>:args) k = splitslvstuff f subs args' args $ \f lsubs largs' largs rsubs rargs' rargs -> k (IntroO1 f) (SubArgsLive SubArgKeep lsubs) ( arg':>:largs') ( arg:>:largs) rsubs rargs' rargs + splitslvstuff (IntroO2 f) (SubArgsLive SubArgKeep subs) (arg':>:args') (arg:>:args) k = splitslvstuff f subs args' args $ \f lsubs largs' largs rsubs rargs' rargs -> k (IntroO2 f) lsubs largs' largs (SubArgsLive SubArgKeep rsubs) ( arg':>:rargs') ( arg:>:rargs) + splitslvstuff (IntroL f) (SubArgsLive SubArgKeep subs) (arg':>:args') (arg:>:args) k = splitslvstuff f subs args' args $ \f lsubs largs' largs rsubs rargs' rargs -> k (IntroL f) (SubArgsLive SubArgKeep lsubs) ( arg':>:largs') ( arg:>:largs) rsubs rargs' rargs + splitslvstuff (IntroR f) (SubArgsLive SubArgKeep subs) (arg':>:args') (arg:>:args) k = splitslvstuff f subs args' args $ \f lsubs largs' largs rsubs rargs' rargs -> k (IntroR f) lsubs largs' largs (SubArgsLive SubArgKeep rsubs) ( arg':>:rargs') ( arg:>:rargs) + +-- Variant of ShrunkOperation where f is not an existential +data ShrunkOperation' op env f where + ShrunkOperation' :: op f -> Args env f -> ShrunkOperation' op env f -- slvSLVedOp :: SLVedOp op f -> SubArgs f f' -> Args env' f' -> ShrunkOperation' (SLVedOp op) env' f' -- slvSLVedOp (SLVOp op subargs) sub args' = ShrunkOperation' (SLVOp op $ composeSubArgs subargs sub) args' -- instance SLVOperation (SLVedOp op) where -- slvOperation (SLVOp op subargs) = Just $ ShrinkOperation (\sub args' _ -> ShrunkOperation (SLVOp op $ composeSubArgs subargs sub) args') + + + + + + +outvar :: Arg env (Out sh e) -> Arg env (Var' sh) +outvar (ArgArray Out (ArrayR shr _) sh _) = ArgVar $ groundToExpVar (shapeType shr) sh + +varout :: Arg env (Var' sh) -> Arg env (Out sh e) +varout (ArgVar sh) = ArgArray Out (ArrayR (fromJust $ typeShape $ varsType sh) undefined) (expToGroundVar sh) undefined + + +slv :: (forall sh e. f (Out sh e) -> f (Var' sh)) -> SubArgs args args' -> PreArgs f args -> PreArgs f args' +slv _ SubArgsNil ArgsNil = ArgsNil +slv f (SubArgsDead sas) (arg:>:args) = f arg :>: slv f sas args +slv f (SubArgsLive SubArgKeep sas) (arg:>:args) = arg :>: slv f sas args +slv _ _ _ = error "not soa'ed" +slv' :: (forall sh e. f (Var' sh) -> f (Out sh e)) -> SubArgs args args' -> PreArgs f args' -> PreArgs f args +slv' _ SubArgsNil ArgsNil = ArgsNil +slv' f (SubArgsDead sas) (arg:>:args) = f arg :>: slv' f sas args +slv' f (SubArgsLive SubArgKeep sas) (arg:>:args) = arg :>: slv' f sas args +slv' _ _ _ = error "not soa'ed" +slvIn :: (forall sh e. f (Var' sh) -> f (Sh sh e)) -> SubArgs args args' -> Env f (InArgs args') -> Env f (InArgs args) +slvIn _ SubArgsNil Empty = Empty +slvIn f (SubArgsDead sas) (Push env x) = Push (slvIn f sas env) (f x) +slvIn f (SubArgsLive SubArgKeep sas) (Push env x) = Push (slvIn f sas env) x +slvIn _ _ _ = error "not soa'ed" +slvOut :: Args env args' -> SubArgs args args' -> Env f (OutArgs args) -> Env f (OutArgs args') +slvOut _ SubArgsNil Empty = Empty +slvOut (_:>:args) (SubArgsDead sas) (Push env _) = slvOut args sas env +slvOut (a :>: args) (SubArgsLive SubArgKeep sas) env = case a of + ArgArray Out _ _ _ + | Push env' x <- env -> Push (slvOut args sas env') x + ArgArray In _ _ _ -> slvOut args sas env + ArgArray Mut _ _ _ -> slvOut args sas env + ArgVar _ -> slvOut args sas env + ArgFun _ -> slvOut args sas env + ArgExp _ -> slvOut args sas env +slvOut _ _ _ = error "not soa'ed" + + + diff --git a/src/Data/Array/Accelerate/Eval.hs b/src/Data/Array/Accelerate/Eval.hs index 4a6360f85..eaf71aa83 100644 --- a/src/Data/Array/Accelerate/Eval.hs +++ b/src/Data/Array/Accelerate/Eval.hs @@ -70,12 +70,14 @@ class ( MakesILP op => StaticClusterAnalysis (op :: Type -> Type) where data BackendClusterArg2 op env arg onOp :: op args -> BackendArgs op env (OutArgsOf args) -> Args env args -> FEnv op env -> BackendArgs op env args + bcaid :: BackendClusterArg op arg -> BackendClusterArg op arg' def :: Arg env arg -> FEnv op env -> BackendClusterArg op arg -> BackendClusterArg2 op env arg valueToIn :: BackendClusterArg2 op env (Value sh e) -> BackendClusterArg2 op env (In sh e) valueToOut :: BackendClusterArg2 op env (Value sh e) -> BackendClusterArg2 op env (Out sh e) inToValue :: BackendClusterArg2 op env (In sh e) -> BackendClusterArg2 op env (Value sh e) outToValue :: BackendClusterArg2 op env (Out sh e) -> BackendClusterArg2 op env (Value sh e) outToSh :: BackendClusterArg2 op env (Out sh e) -> BackendClusterArg2 op env (Sh sh e) + outToVar :: BackendClusterArg2 op env (Out sh e) -> BackendClusterArg2 op env (Var' sh ) shToOut :: BackendClusterArg2 op env (Sh sh e) -> BackendClusterArg2 op env (Out sh e) shToValue :: BackendClusterArg2 op env (Sh sh e) -> BackendClusterArg2 op env (Value sh e) varToValue :: BackendClusterArg2 op env (Var' sh) -> BackendClusterArg2 op env (Value sh e) @@ -91,18 +93,37 @@ class ( MakesILP op unpairinfo :: BackendClusterArg2 op env (m sh (a,b)) -> (BackendClusterArg2 op env (m sh a), BackendClusterArg2 op env (m sh b)) unpairinfo x = (shrinkOrGrow x, shrinkOrGrow x) - +foo :: StaticClusterAnalysis op => SubArgs big small -> Args env small -> BackendArgs op env (OutArgsOf small) -> FEnv op env -> BackendCluster op small -> BackendArgs op env (OutArgsOf big) +foo SubArgsNil ArgsNil ArgsNil _ _ = ArgsNil +foo (SubArgKeep `SubArgsLive` subargs) (a:>:as) bs env (_ :>: cs) = case a of + ArgArray Out _ _ _ -> case bs of (b:>:bs') -> b :>: foo subargs as bs' env cs + ArgArray In _ _ _ -> foo subargs as bs env cs + ArgArray Mut _ _ _ -> foo subargs as bs env cs + ArgVar _ -> foo subargs as bs env cs + ArgExp _ -> foo subargs as bs env cs + ArgFun _ -> foo subargs as bs env cs +foo (SubArgOut _ `SubArgsLive` subargs) (_ :>: as) (b:>:bs) env (_ :>: cs) = shrinkOrGrow b :>: foo subargs as bs env cs +foo (SubArgsDead subargs) (a :>: as) bs env (c :>: cs) = + shToOut (varToSh (def a env c)) :>: foo subargs as bs env cs makeBackendArg :: forall op env args. StaticClusterAnalysis op => Args env args -> FEnv op env -> Cluster op args -> BackendCluster op args -> BackendArgs op env args -makeBackendArg args env c b = go args c (defaultOuts args b) +makeBackendArg args env c b = go args c (defaultOuts args b) b where - go :: forall args. Args env args -> Cluster op args -> BackendArgs op env (OutArgsOf args) -> BackendArgs op env args - go args (Fused f l r) outputs = let - backR = go (right f args) r (rightB args f outputs) - backL = go (left f args) l (backleft f backR outputs) + go :: forall args. Args env args -> Cluster op args -> BackendArgs op env (OutArgsOf args) -> BackendCluster op args -> BackendArgs op env args + go args (Fused f l r) outputs bs = let + backR = go (right f args) r (rightB args f outputs) (right' bcaid bcaid f bs) + backL = go (left f args) l (backleft f backR outputs) (left' bcaid f bs) in fuseBack f backL backR - go args (Op (SOp (SOAOp op soa) (SA sort unsort)) sa) outputs = - sort . soaExpand uncombineB soa $ onOp @op op (forgetIn (soaShrink combine soa $ unsort args) $ soaShrink combineB soa $ unsort $ inventIn args outputs) (soaShrink combine soa $ unsort args) env + go args (Op (SLV (SOp (SOAOp op soa) (SA sort unsort)) subargs) _l) outputs bs = + slv outToVar subargs + . sort + . soaExpand uncombineB soa + $ onOp @op + op + (forgetIn (soaShrink combine soa . unsort $ slv' varout subargs args) + . soaShrink combineB soa . unsort $ inventIn (slv' varout subargs args) (foo subargs args outputs env bs)) + (soaShrink combine soa . unsort $ slv' varout subargs args) + env combineB :: BackendClusterArg2 op env (f l) -> BackendClusterArg2 op env (f r) -> BackendClusterArg2 op env (f (l,r)) combineB = unsafeCoerce $ pairinfo @op @@ -222,10 +243,12 @@ evalCluster c b args env ix = do evalOps :: forall op args env. (EvalOp op) => Index op -> Cluster op args -> BackendArgEnv op env (InArgs args) -> Args env args -> FEnv op env -> EvalMonad op (EmbedEnv op env (OutArgs args)) evalOps ix c ba args env = case c of - Op (SOp (SOAOp op soas) (SA f g)) l -> outargs f (g args) - . soaOut splitFromArg' (soaShrink combine soas $ g args) soas - <$> evalOp ix l op env (soaIn pairInArg (g args) soas - $ inargs g ba) + Op (SLV (SOp (SOAOp op soas) (SA f g)) subargs) l + -> slvOut args subargs + . outargs f (g $ slv' varout subargs args) + . soaOut splitFromArg' (soaShrink combine soas $ g $ slv' varout subargs args) soas + <$> evalOp ix l op env (soaIn pairInArg (g $ slv' varout subargs args) soas + $ inargs g $ slvIn (flip bvartosh env) subargs ba) Fused f l r -> do lin <- leftIn f ba env lout <- evalOps ix l lin (left f args) env @@ -384,7 +407,7 @@ instance EvalOp op => TupRmonoid (Compose (BackendArgEnvElem op env) (Sh sh)) wh -- use this to check whether a singleton cluster is a generate, map, etc peekSingletonCluster :: (forall args'. op args' -> r) -> Cluster op args -> Maybe r peekSingletonCluster k = \case - Op (SOp (SOAOp op _) _) _ -> Just $ k op + Op (SLV (SOp (SOAOp op _) _) _) _ -> Just $ k op _ -> Nothing -- not a singleton cluster @@ -396,11 +419,11 @@ applySingletonCluster :: forall op env args args' r -> Args env args -> r applySingletonCluster k c args = case c of - Op (SOp (SOAOp op soas) (SA _ unsort)) _ -> + Op (SLV (SOp (SOAOp op soas) (SA _ unsort)) subargs) _ -> unsafeCoerce @(op args' -> Args env args' -> r) @(op _ -> Args env _ -> r) k op - $ soaShrink combine soas $ unsort args + $ soaShrink combine soas $ unsort $ slv' varout subargs args _ -> error "not singleton" @@ -408,15 +431,16 @@ applySingletonCluster k c args = case c of applySingletonCluster' :: forall op env args args' f . (op args' -> Args env args' -> PreArgs f args') -> (forall l r g. f (g (l,r)) -> (f (g l), f (g r))) + -> (forall sh e. f (Out sh e) -> f (Var' sh)) -> Cluster op args -> Args env args -> PreArgs f args -applySingletonCluster' k f c args = case c of - Op (SOp (SOAOp op soas) (SA sort unsort)) _ -> - sort $ soaExpand f soas $ +applySingletonCluster' k f outvar' c args = case c of + Op (SLV (SOp (SOAOp op soas) (SA sort unsort)) subargs) _ -> + slv outvar' subargs $ sort $ soaExpand f soas $ unsafeCoerce @(op args' -> Args env args' -> PreArgs f args') @(op _ -> Args env _ -> PreArgs f _) k op - $ soaShrink combine soas $ unsort args + $ soaShrink combine soas $ unsort $ slv' varout subargs args _ -> error "not singleton" diff --git a/src/Data/Array/Accelerate/Pretty/Partitioned.hs b/src/Data/Array/Accelerate/Pretty/Partitioned.hs index 938d7d3ed..938797bda 100644 --- a/src/Data/Array/Accelerate/Pretty/Partitioned.hs +++ b/src/Data/Array/Accelerate/Pretty/Partitioned.hs @@ -37,6 +37,8 @@ import Prelude hiding (exp) import Data.Array.Accelerate.Representation.Type (TupR (..)) import Data.Array.Accelerate.AST.Idx (Idx (..)) import Data.Bifunctor (second) +import Data.Array.Accelerate.Representation.Array (ArrayR(..)) +import Data.Array.Accelerate.AST.Var (varsType) instance PrettyOp op => PrettyOp (Clustered op) where prettyOp :: PrettyOp op => Clustered op t -> Adoc @@ -45,9 +47,10 @@ instance PrettyOp op => PrettyOp (Clustered op) where instance PrettyOp op => PrettyOp (Cluster op) where prettyOp (Fused _ l r) = "Fused (" <> prettyOp l <> ", " <> prettyOp r - prettyOp (Op (SOp (SOAOp op _) _) _) = prettyOp op - prettyOpWithArgs env (Fused f l r) args = "Fused (" <> prettyOpWithArgs env l (left f args) <> ", " <> prettyOpWithArgs env r (right f args) - prettyOpWithArgs env (Op (SOp (SOAOp op soa) (SA _ unsort)) _) args = prettyOpWithArgs env op (soaShrink combine soa . unsort $ args) + prettyOp (Op (SLV (SOp (SOAOp op _) _) _) _) = prettyOp op + prettyOpWithArgs env (Fused f l r) args = "Fused" -- (" <> prettyOpWithArgs env l (left f args) <> ", " <> prettyOpWithArgs env r (right f args) + prettyOpWithArgs env (Op (SLV (SOp (SOAOp op soa) (SA _ unsort)) subargs) _) args = + prettyOpWithArgs env op (soaShrink combine soa . unsort . slv' varout subargs $ args) -- prettyOpWithArgs :: forall env t. Val env -> Cluster op t -> Args env t -> Adoc -- prettyOpWithArgs env (Op (SLVOp (SOp (SOAOp op soa) (SA _ unsort)) sa) _) args = prettyOpWithArgs env op (soaShrink combine soa . unsort . slv' varToOut sa $ args) diff --git a/src/Data/Array/Accelerate/Trafo/Operation/LiveVars.hs b/src/Data/Array/Accelerate/Trafo/Operation/LiveVars.hs index 943e74ac0..193a85f75 100644 --- a/src/Data/Array/Accelerate/Trafo/Operation/LiveVars.hs +++ b/src/Data/Array/Accelerate/Trafo/Operation/LiveVars.hs @@ -230,6 +230,7 @@ class SLVOperation op where newtype ShrinkOperation op f = ShrinkOperation (forall f' env' env. SubArgs f f' -> Args env' f' -> Args env f -> ShrunkOperation op env') +-- existential over f: otherwise, you couldn't change the non-array arguments. You need this e.g. for a Generate: smaller array means smaller function. data ShrunkOperation op env where ShrunkOperation :: op f -> Args env f -> ShrunkOperation op env diff --git a/src/Data/Array/Accelerate/Trafo/Partitioning/ILP/Clustering.hs b/src/Data/Array/Accelerate/Trafo/Partitioning/ILP/Clustering.hs index cb38dee4a..1ac9177f7 100644 --- a/src/Data/Array/Accelerate/Trafo/Partitioning/ILP/Clustering.hs +++ b/src/Data/Array/Accelerate/Trafo/Partitioning/ILP/Clustering.hs @@ -20,6 +20,7 @@ -Wno-overlapping-patterns -Wno-incomplete-patterns #-} +{-# LANGUAGE BlockArguments #-} module Data.Array.Accelerate.Trafo.Partitioning.ILP.Clustering where @@ -314,10 +315,11 @@ data FoldType op env unfused :: forall op args env r. MakesILP op => op args -> Label -> LabelledArgsOp op env args -> (forall args'. Clustered op args' -> LabelledArgsOp op env args' -> r) -> r -unfused op l largs k = singleton l largs op $ - \c@(Clustered (Op (SOp (SOAOp (_op :: op argsToo) soas) (SA sort _unsort)) _) b) -> +unfused op l largs k = singleton l largs op \case + c@(Clustered (Op (SLV (SOp (SOAOp (_op :: op argsToo) soas) (SA sort _unsort)) subargs) _l) _b) -> case unsafeCoerce Refl of -- we know that `_op` is the same as `op` - (Refl :: args :~: argsToo) -> k c (sort $ soaExpand splitLabelledArgsOp soas largs) + (Refl :: args :~: argsToo) -> k c (slv louttovar subargs $ sort $ soaExpand splitLabelledArgsOp soas largs) + _ -> error "singleton gave fused" louttovar :: LabelledArgOp op env (Out sh e) -> LabelledArgOp op env (Var' sh) louttovar (LOp a (_,ls) b) = LOp (outvar a) (NotArr, ls) b -- unsafe marker: maybe this NotArr ends up a problem? diff --git a/src/Data/Array/Accelerate/Trafo/Partitioning/ILP/Graph.hs b/src/Data/Array/Accelerate/Trafo/Partitioning/ILP/Graph.hs index b7b4c94d7..6fa606158 100644 --- a/src/Data/Array/Accelerate/Trafo/Partitioning/ILP/Graph.hs +++ b/src/Data/Array/Accelerate/Trafo/Partitioning/ILP/Graph.hs @@ -32,6 +32,8 @@ import Data.Array.Accelerate.Trafo.Partitioning.ILP.Solver import Data.Array.Accelerate.Type import Data.Array.Accelerate.Analysis.Hash.Exp +import Data.Array.Accelerate.Trafo.Operation.LiveVars + -- Data structures -- In this file, order often subly matters. -- To keep this clear, we use Set whenever it does not, @@ -160,7 +162,7 @@ unOpLabels = mapArgs $ \(LOp arg l _) -> L arg l type BackendCluster op = PreArgs (BackendClusterArg op) -class (Eq (BackendVar op), Ord (BackendVar op), Eq (BackendArg op), Show (BackendArg op), Ord (BackendArg op), Show (BackendVar op)) => MakesILP op where +class (ShrinkArg (BackendClusterArg op), Eq (BackendVar op), Ord (BackendVar op), Eq (BackendArg op), Show (BackendArg op), Ord (BackendArg op), Show (BackendVar op)) => MakesILP op where -- Vars needed to express backend-specific fusion rules. type BackendVar op -- Information that the backend attaches to the argument for reconstruction,