From 6339274ea22ef472a6d75c9e5403d190b31f145e Mon Sep 17 00:00:00 2001 From: David Date: Fri, 6 Oct 2023 12:33:31 +0200 Subject: [PATCH] up to date with cluster representation refactor --- .../Array/Accelerate/LLVM/Native/CodeGen.hs | 86 +++++++++---------- .../Array/Accelerate/LLVM/Native/Operation.hs | 10 +-- accelerate-llvm-native/test/nofib/Main.hs | 9 +- stack.yaml | 12 +-- 4 files changed, 57 insertions(+), 60 deletions(-) diff --git a/accelerate-llvm-native/src/Data/Array/Accelerate/LLVM/Native/CodeGen.hs b/accelerate-llvm-native/src/Data/Array/Accelerate/LLVM/Native/CodeGen.hs index 9237f8181..e3f941657 100644 --- a/accelerate-llvm-native/src/Data/Array/Accelerate/LLVM/Native/CodeGen.hs +++ b/accelerate-llvm-native/src/Data/Array/Accelerate/LLVM/Native/CodeGen.hs @@ -37,7 +37,7 @@ import Data.Array.Accelerate.Representation.Array import Data.Array.Accelerate.Representation.Shape (shapeType, ShapeR(..), rank) import Data.Array.Accelerate.Representation.Type import Data.Array.Accelerate.AST.Exp -import Data.Array.Accelerate.AST.Partitioned as P +import Data.Array.Accelerate.AST.Partitioned as P hiding (combine) import Data.Array.Accelerate.AST.Var import Data.Array.Accelerate.Eval import Data.Array.Accelerate.Type @@ -64,7 +64,7 @@ import Data.Array.Accelerate.Trafo.Desugar (ArrayDescriptor(..)) import Data.Array.Accelerate.LLVM.CodeGen.Arithmetic ( when, ifThenElse, just, add, lt, mul, eq, fromJust, isJust ) import Data.Array.Accelerate.Analysis.Match (matchShapeR) import Data.Array.Accelerate.Trafo.Exp.Substitution (compose) -import Data.Array.Accelerate.AST.Operation (groundToExpVar, Fun) +import Data.Array.Accelerate.AST.Operation (groundToExpVar, Fun, mapArgs) import Data.Array.Accelerate.LLVM.Native.CodeGen.Permute (atomically) import Control.Monad.State (StateT(..), lift, evalStateT, execStateT) import qualified Data.Map as M @@ -80,37 +80,40 @@ import Data.Array.Accelerate.LLVM.CodeGen.Loop (iter, while) import Data.Array.Accelerate.LLVM.Native.CodeGen.Loop import qualified Debug.Trace import Formatting.ShortFormatters (o) +import Data.Array.Accelerate (SLVOperation) +import Data.Array.Accelerate.Backend (SLVOperation(..)) codegen :: UID -> Env AccessGroundR env - -> Cluster NativeOp args + -> Clustered NativeOp args -> Args env args -> LLVM Native (Module (KernelType env)) -codegen uid env c@(Cluster _ (Cluster' cIO cAST)) args = +codegen uid env (Clustered c b) args = codeGenFunction uid "fused_cluster_name" (LLVM.Lam argTp "arg") $ do extractEnv workstealLoop workstealIndex workstealActiveThreads (op scalarTypeInt32 $ constant (TupRsingle scalarTypeInt32) 1) $ \_ -> do - (acc, loopsize') <- execStateT (evalCluster (toOnlyAcc c) args gamma ()) (mempty, LS ShapeRz OP_Unit) + let b' = mapArgs BCAJA b + (acc, loopsize') <- execStateT (evalCluster (toOnlyAcc c) b' args gamma ()) (mempty, LS ShapeRz OP_Unit) body acc loopsize' retval_ $ boolean True where (argTp, extractEnv, workstealIndex, workstealActiveThreads, gamma) = bindHeaderEnv env body :: Accumulated -> Loopsizes -> CodeGen Native () body initialAcc partialLoopSize = - case combine (loopsizeOutVertical c gamma args) partialLoopSize of + case partialLoopSize of -- used to combine with loopSize here, but I think we can do everything in the static analysis? LS shr' sh' -> let go :: ShapeR sh -> Operands sh -> (Int, Operands Int, [Operands Int]) -> StateT Accumulated (CodeGen Native) () go ShapeRz OP_Unit _ = pure () go (ShapeRsnoc shr) (OP_Pair sh sz) ix = iter' sz ix $ \i@(d,lin,is) -> do recurLin <- lift $ mul numType lin (firstOrZero shr sh) go shr sh (d+1, recurLin, is) - let ba = makeBackendArg @NativeOp args gamma c - newInputs <- evalI @NativeOp i cIO args ba gamma - outputs <- evalAST @NativeOp i cAST gamma newInputs - evalO @NativeOp i cIO args gamma outputs + let ba = makeBackendArg @NativeOp args gamma c b + newInputs <- readInputs @_ @_ @NativeOp i args ba gamma + outputs <- evalOps @NativeOp i c newInputs args gamma + writeOutputs @_ @_ @NativeOp i args outputs gamma in case (shr', flipShape shr' sh') of -- (ShapeRz,_) -> error "tiny cluster" -- (ShapeRsnoc shr', OP_Pair sh' sz) -> @@ -174,19 +177,20 @@ flipShape shr = multidim shr . reverse . multidim' shr -- TODO: we need to only consider each _in-order_ vertical argument -- TODO: we ignore backpermute currently. Could use this function to check the outputs and vertical, and the staticclusteranalysis evalI for the inputs. -- e.g. backpermute . fold: shape of backpermute output plus the extra dimension of fold. -loopsizeOutVertical :: forall args env. Cluster NativeOp args -> Gamma env -> Args env args -> Loopsizes -loopsizeOutVertical (Cluster _ (Cluster' io _)) gamma args = go io args - where - go :: ClusterIO a i o -> Args env a -> Loopsizes - go Empty ArgsNil = LS ShapeRz OP_Unit - go (Input io') (ArgArray _ (ArrayR shr _) sh _ :>: args') = go io' args' -- $ \x -> combine x (shr, aprjParameters (unsafeToExpVars sh) gamma) k - go (Output _ _ _ io') (ArgArray _ (ArrayR shr _) sh _ :>: args') = combine (go io' args') $ LS shr (aprjParameters (unsafeToExpVars sh) gamma) - go (Trivial io') (ArgArray _ (ArrayR shr _) sh _ :>: args') = combine (go io' args') $ LS shr (aprjParameters (unsafeToExpVars sh) gamma) - go (Vertical _ (ArrayR shr _) io') (ArgVar sh :>: args') = combine (go io' args') $ LS shr (aprjParameters sh gamma) - go (MutPut io') (_ :>: args') = go io' args' - go (ExpPut io') (_ :>: args') = go io' args' - go (VarPut io') (_ :>: args') = go io' args' - go (FunPut io') (_ :>: args') = go io' args' +-- loopsizeOutVertical :: forall args env. Cluster NativeOp args -> Gamma env -> Args env args -> Loopsizes +-- loopsizeOutVertical = undefined +-- loopsizeOutVertical (Cluster _ (Cluster' io _)) gamma args = go io args +-- where +-- go :: ClusterIO a i o -> Args env a -> Loopsizes +-- go Empty ArgsNil = LS ShapeRz OP_Unit +-- go (Input io') (ArgArray _ (ArrayR shr _) sh _ :>: args') = go io' args' -- $ \x -> combine x (shr, aprjParameters (unsafeToExpVars sh) gamma) k +-- go (Output _ _ _ io') (ArgArray _ (ArrayR shr _) sh _ :>: args') = combine (go io' args') $ LS shr (aprjParameters (unsafeToExpVars sh) gamma) +-- go (Trivial io') (ArgArray _ (ArrayR shr _) sh _ :>: args') = combine (go io' args') $ LS shr (aprjParameters (unsafeToExpVars sh) gamma) +-- go (Vertical _ (ArrayR shr _) io') (ArgVar sh :>: args') = combine (go io' args') $ LS shr (aprjParameters sh gamma) +-- go (MutPut io') (_ :>: args') = go io' args' +-- go (ExpPut io') (_ :>: args') = go io' args' +-- go (VarPut io') (_ :>: args') = go io' args' +-- go (FunPut io') (_ :>: args') = go io' args' -- get the largest ShapeR, and the corresponding shape combine :: Loopsizes -> Loopsizes -> Loopsizes combine x@(LS a _) y@(LS b _) = if rank a > rank b then x else y @@ -203,6 +207,8 @@ instance EvalOp NativeOp where type Embed' NativeOp = Compose Maybe Operands type EnvF NativeOp = GroundOperand + unit = Compose $ Just OP_Unit + -- don't need to be in the monad! indexsh vars gamma = pure . CJ $ aprjParameters (unsafeToExpVars vars) gamma indexsh' vars gamma = pure . CJ $ aprjParameters vars gamma @@ -321,18 +327,10 @@ multidim' (ShapeRsnoc shr) (OP_Pair sh i) = i : multidim' shr sh instance TupRmonoid Operands where pair' = OP_Pair unpair' (OP_Pair x y) = (x, y) - injL x p = OP_Pair x (proofToOp p) - injR x p = OP_Pair (proofToOp p) x instance (TupRmonoid f) => TupRmonoid (Compose Maybe f) where pair' (Compose l) (Compose r) = Compose $ pair' <$> l <*> r unpair' (Compose p) = maybe (CN, CN) (bimap CJ CJ . unpair') p - injL (Compose x) p = Compose $ (`injL` p) <$> x - injR (Compose x) p = Compose $ (`injR` p) <$> x - -proofToOp :: TupUnitsProof a -> Operands a -proofToOp OneUnit = OP_Unit -proofToOp (MoreUnits x y) = OP_Pair (proofToOp x) (proofToOp y) unsafeToExpVars :: GroundVars env sh -> ExpVars env sh unsafeToExpVars TupRunit = TupRunit @@ -341,6 +339,8 @@ unsafeToExpVars (TupRsingle (Var g idx)) = case g of GroundRbuffer _ -> error "unsafeToExpVars on a buffer" GroundRscalar t -> TupRsingle (Var t idx) +instance SLVOperation NativeOp where + slvOperation = const Nothing maybeTy :: TypeR a -> TypeR (PrimMaybe a) maybeTy ty = TupRpair (TupRsingle scalarTypeWord8) (TupRpair TupRunit ty) @@ -360,7 +360,7 @@ data Loopsizes where LS :: ShapeR sh -> Operands sh -> Loopsizes merge :: Loopsizes -> GroundVars env sh -> Gamma env -> Loopsizes -merge ls v gamma = combine ls $ LS (varsToShapeR v) (aprjParameters (unsafeToExpVars v) gamma) +merge ls v gamma = combine ls $ LS (gvarsToShapeR v) (aprjParameters (unsafeToExpVars v) gamma) -- mkls :: Int -> ShapeR sh -> Operands sh -> Bool -> Loopsizes -- mkls d shr sh b @@ -393,6 +393,8 @@ instance EvalOp (JustAccumulator NativeOp) where type Embed' (JustAccumulator NativeOp) = TypeR type EnvF (JustAccumulator NativeOp) = GroundOperand + unit = TupRunit + indexsh vars _ = pure $ mapTupR varType $ unsafeToExpVars vars indexsh' vars _ = pure $ mapTupR varType vars @@ -466,14 +468,8 @@ deriving instance (Eq (BackendClusterArg2 op x y)) => Eq (BackendClusterArg2 (Ju toOnlyAcc :: Cluster op args -> Cluster (JustAccumulator op) args -toOnlyAcc (Cluster bc (Cluster' io ast)) = Cluster (go2 bc) (Cluster' io $ go ast) - where - go :: ClusterAST op env args' -> ClusterAST (JustAccumulator op) env args' - go P.None = P.None - go (Bind lhs op l ast') = Bind lhs (JA op) l (go ast') - go2 :: BackendCluster op args -> BackendCluster (JustAccumulator op) args - go2 ArgsNil = ArgsNil - go2 (bca :>: args) = BCAJA bca :>: go2 args +toOnlyAcc (Fused f l r) = Fused f (toOnlyAcc l) (toOnlyAcc r) +toOnlyAcc (Op (SLVOp (SOp (SOAOp op soa) sort) sa)) = Op (SLVOp (SOp (SOAOp (JA op) soa) sort) sa) pattern CJ :: f a -> Compose Maybe f a pattern CJ x = Compose (Just x) @@ -523,9 +519,9 @@ zeroes ty@(TupRsingle t) = case t of TypeFloat -> constant ty 0 TypeDouble -> constant ty 0 -varsToShapeR :: Vars GroundR x sh -> ShapeR sh -varsToShapeR TupRunit = ShapeRz -varsToShapeR (TupRpair sh (TupRsingle x)) = case x of - Var (GroundRscalar (SingleScalarType (NumSingleType (IntegralNumType TypeInt)))) _ -> ShapeRsnoc $ varsToShapeR sh +gvarsToShapeR :: Vars GroundR x sh -> ShapeR sh +gvarsToShapeR TupRunit = ShapeRz +gvarsToShapeR (TupRpair sh (TupRsingle x)) = case x of + Var (GroundRscalar (SingleScalarType (NumSingleType (IntegralNumType TypeInt)))) _ -> ShapeRsnoc $ gvarsToShapeR sh _ -> error "not a shape" -varsToShapeR _ = error "not a shape" +gvarsToShapeR _ = error "not a shape" diff --git a/accelerate-llvm-native/src/Data/Array/Accelerate/LLVM/Native/Operation.hs b/accelerate-llvm-native/src/Data/Array/Accelerate/LLVM/Native/Operation.hs index bb9331d01..1a9faeddf 100644 --- a/accelerate-llvm-native/src/Data/Array/Accelerate/LLVM/Native/Operation.hs +++ b/accelerate-llvm-native/src/Data/Array/Accelerate/LLVM/Native/Operation.hs @@ -169,11 +169,11 @@ instance SimplifyOperation NativeOp where detectCopy matchVars' NBackpermute = detectBackpermuteCopies matchVars' detectCopy _ _ = const [] -instance SLVOperation NativeOp where - slvOperation NGenerate = defaultSlvGenerate NGenerate - slvOperation NMap = defaultSlvMap NMap - slvOperation NBackpermute = defaultSlvBackpermute NBackpermute - slvOperation _ = Nothing +-- instance SLVOperation NativeOp where +-- slvOperation NGenerate = defaultSlvGenerate NGenerate +-- slvOperation NMap = defaultSlvMap NMap +-- slvOperation NBackpermute = defaultSlvBackpermute NBackpermute +-- slvOperation _ = Nothing instance EncodeOperation NativeOp where encodeOperation NMap = intHost $(hashQ ("Map" :: String)) diff --git a/accelerate-llvm-native/test/nofib/Main.hs b/accelerate-llvm-native/test/nofib/Main.hs index 9e2bde507..52c8de506 100644 --- a/accelerate-llvm-native/test/nofib/Main.hs +++ b/accelerate-llvm-native/test/nofib/Main.hs @@ -31,11 +31,12 @@ import Data.Array.Accelerate.Unsafe main :: IO () main = do + Prelude.print $ runN @Interpreter complex $ fromList (Z:.100) $ Prelude.map (`Prelude.mod` 50) [1 :: Int ..] -- benchmarking: - defaultMain $ - Prelude.map (benchOption . Prelude.Left) [minBound :: Objective .. maxBound] - Prelude.++ - Prelude.map (benchOption . Prelude.Right) [NoFusion, GreedyFusion] + -- defaultMain $ + -- Prelude.map (benchOption . Prelude.Left) [minBound :: Objective .. maxBound] + -- Prelude.++ + -- Prelude.map (benchOption . Prelude.Right) [NoFusion, GreedyFusion] -- Prelude.print $ runNWithObj @Native ArrayReadsWrites $ quicksort $ use $ fromList (Z :. 5) [100::Int, 200, 3, 5, 4] where diff --git a/stack.yaml b/stack.yaml index 4c3c58567..c82f7ffd7 100644 --- a/stack.yaml +++ b/stack.yaml @@ -2,7 +2,7 @@ # For advanced use and comprehensive documentation of the format, please see: # https://docs.haskellstack.org/en/stable/yaml_configuration/ -resolver: lts-21.7 +resolver: lts-21.12 packages: - accelerate-llvm @@ -15,11 +15,11 @@ extra-deps: - OptDir-0.0.4 - bytestring-encoding-0.1.2.0 - ../accelerate - -- github: msakai/haskell-MIP - commit: 4295aa21a24a30926b55770c55ac00f749fb8a39 - subdirs: - - MIP +- MIP-0.1.1.0 +# - github: msakai/haskell-MIP +# commit: 4295aa21a24a30926b55770c55ac00f749fb8a39 +# subdirs: +# - MIP - github: llvm-hs/llvm-hs commit: e4b3cfa47e72f094ab109884f18acfc666b0fb7d # llvm-15