From 165a9e787fd09450f54474956b3b0e553f30fee9 Mon Sep 17 00:00:00 2001 From: David van Balen Date: Thu, 27 Jun 2024 16:28:20 +0200 Subject: [PATCH] cleanup --- accelerate-llvm-native/greediesarebad.csv | 1 + .../src/Data/Array/Accelerate/LLVM/Native.hs | 7 +- .../Array/Accelerate/LLVM/Native/CodeGen.hs | 135 ++++++------------ .../Array/Accelerate/LLVM/Native/Operation.hs | 8 +- 4 files changed, 46 insertions(+), 105 deletions(-) diff --git a/accelerate-llvm-native/greediesarebad.csv b/accelerate-llvm-native/greediesarebad.csv index a214388d0..12f1b97fd 100644 --- a/accelerate-llvm-native/greediesarebad.csv +++ b/accelerate-llvm-native/greediesarebad.csv @@ -8,3 +8,4 @@ GreedyDown/backwardbad,0.708492649517366,0.6981189270352098,0.7464575397676662,5 NoFusion/forwardbad,9.706273619993701e-2,9.441343113486038e-2,0.10235459452755384,1.469075659904803e-2,8.002086943955398e-3,2.5279871688187384e-2 NoFusion/backwardbad,1.341660145146297,1.304144801615958,1.3596349875245755,5.49257811758058e-2,1.3556280370411532e-2,8.886663759599717e-2 Name,Mean,MeanLB,MeanUB,Stddev,StddevLB,StddevUB +Name,Mean,MeanLB,MeanUB,Stddev,StddevLB,StddevUB diff --git a/accelerate-llvm-native/src/Data/Array/Accelerate/LLVM/Native.hs b/accelerate-llvm-native/src/Data/Array/Accelerate/LLVM/Native.hs index 38fe10f8b..b61ab7316 100644 --- a/accelerate-llvm-native/src/Data/Array/Accelerate/LLVM/Native.hs +++ b/accelerate-llvm-native/src/Data/Array/Accelerate/LLVM/Native.hs @@ -1,15 +1,10 @@ -{-# LANGUAGE BangPatterns #-} {-# LANGUAGE CPP #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} -{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TemplateHaskell #-} -{-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} -{-# LANGUAGE TypeSynonymInstances #-} + -- | -- Module : Data.Array.Accelerate.LLVM.Native -- Copyright : [2014..2020] The Accelerate Team diff --git a/accelerate-llvm-native/src/Data/Array/Accelerate/LLVM/Native/CodeGen.hs b/accelerate-llvm-native/src/Data/Array/Accelerate/LLVM/Native/CodeGen.hs index d597ff595..d68421689 100644 --- a/accelerate-llvm-native/src/Data/Array/Accelerate/LLVM/Native/CodeGen.hs +++ b/accelerate-llvm-native/src/Data/Array/Accelerate/LLVM/Native/CodeGen.hs @@ -10,12 +10,12 @@ {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE TupleSections #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} {-# LANGUAGE ViewPatterns #-} +{-# OPTIONS_GHC -Wno-name-shadowing #-} -- | -- Module : Data.Array.Accelerate.LLVM.Native.CodeGen @@ -40,16 +40,12 @@ import Data.Array.Accelerate.Representation.Shape (shapeType, ShapeR(..), rank) import Data.Array.Accelerate.Representation.Type import Data.Array.Accelerate.AST.Exp import Data.Array.Accelerate.AST.Partitioned as P hiding (combine) -import Data.Array.Accelerate.AST.Var import Data.Array.Accelerate.Eval import Data.Array.Accelerate.Type -import Data.Array.Accelerate.Trafo.LiveVars import qualified Data.Array.Accelerate.AST.Environment as Env import Data.Array.Accelerate.LLVM.State -import Data.Array.Accelerate.LLVM.Compile.Cache import Data.Array.Accelerate.LLVM.CodeGen.Environment hiding ( Empty ) import Data.Array.Accelerate.LLVM.Native.Operation --- import Data.Array.Accelerate.LLVM.Native.CodeGen.Skeleton import Data.Array.Accelerate.LLVM.Native.CodeGen.Base import Data.Array.Accelerate.LLVM.Native.Target import Data.Typeable @@ -58,55 +54,44 @@ import LLVM.AST.Type.Module import Data.Array.Accelerate.LLVM.CodeGen.Monad import qualified LLVM.AST.Type.Function as LLVM import Data.Array.Accelerate.LLVM.CodeGen.Array -import Data.Array.Accelerate.LLVM.CodeGen.IR (Operands (..), IROP (..)) import Unsafe.Coerce (unsafeCoerce) import Data.Array.Accelerate.LLVM.CodeGen.Sugar (app1, IROpenFun2 (app2)) import Data.Array.Accelerate.LLVM.CodeGen.Exp (llvmOfFun1, intOfIndex, llvmOfFun2) import Data.Array.Accelerate.Trafo.Desugar (ArrayDescriptor(..)) -import Data.Array.Accelerate.LLVM.CodeGen.Arithmetic ( when, ifThenElse, just, add, lt, mul, eq, fromJust, isJust, liftInt ) +import Data.Array.Accelerate.LLVM.CodeGen.Arithmetic ( when, ifThenElse, eq, fromJust, isJust, liftInt ) import Data.Array.Accelerate.Analysis.Match (matchShapeR) import Data.Array.Accelerate.Trafo.Exp.Substitution (compose) -import Data.Array.Accelerate.AST.Operation (groundToExpVar, Fun, mapArgs) import Data.Array.Accelerate.LLVM.Native.CodeGen.Permute (atomically) -import Control.Monad.State (StateT(..), lift, evalStateT, execStateT) +import Control.Monad.State (StateT(..), lift, execStateT) import qualified Data.Map as M import Data.ByteString.Short ( ShortByteString ) import Data.Array.Accelerate.AST.LeftHandSide (Exists (Exists), lhsToTupR) import Data.Array.Accelerate.Trafo.Partitioning.ILP.Labels (Label) -import Data.Array.Accelerate.LLVM.CodeGen.Constant (constant, boolean) -import Data.Array.Accelerate.Trafo.Partitioning.ILP.Graph (MakesILP(..), BackendCluster) +import Data.Array.Accelerate.LLVM.CodeGen.Constant (constant) +import Data.Array.Accelerate.Trafo.Partitioning.ILP.Graph (MakesILP(..)) import Data.Coerce (coerce) import Data.Functor.Compose import qualified Control.Monad as Prelude import Data.Bifunctor (Bifunctor(..)) -import Data.Array.Accelerate.LLVM.CodeGen.Loop (iter, while) import Data.Array.Accelerate.LLVM.Native.CodeGen.Loop -import qualified Debug.Trace -import Formatting.ShortFormatters (o) -import Data.Array.Accelerate (SLVOperation) -import Data.Array.Accelerate.Backend (SLVOperation(..)) import Data.Array.Accelerate.LLVM.CodeGen.IR - -traceIfDebugging str a = a --Debug.Trace.trace str a +traceIfDebugging :: String -> a -> a +traceIfDebugging _ a = a --Debug.Trace.trace str a codegen :: ShortByteString -> Env AccessGroundR env -> Clustered NativeOp args -> Args env args -> LLVM Native (Module (KernelType env)) -codegen name env (Clustered c b) args = +codegen name env (Clustered c b) args = codeGenFunction name (LLVM.Lam argTp "arg") $ do - -- putchar $ liftInt 73 extractEnv - -- putchar $ liftInt 74 - -- workstealLoop workstealIndex workstealActiveThreads (op scalarTypeInt32 $ constant (TupRsingle scalarTypeInt32) 1) $ \_ -> do let b' = mapArgs BCAJA b (acc, loopsize) <- execStateT (evalCluster (toOnlyAcc c) b' args gamma ()) (mempty, LS ShapeRz OP_Unit) - -- body acc loopsize' - acc' <- operandsMapToPairs acc $ \(accTypeR, toOp, fromOp) -> fmap fromOp $ flip execStateT (toOp acc) $ case loopsize of - LS loopshr loopsh -> -- Debug.Trace.traceShow (rank loopshr) $ - workstealChunked loopshr workstealIndex workstealActiveThreads (flipShape loopshr loopsh) accTypeR + _acc' <- operandsMapToPairs acc $ \(accTypeR, toOp, fromOp) -> fmap fromOp $ flip execStateT (toOp acc) $ case loopsize of + LS loopshr loopsh -> + workstealChunked loopshr workstealIndex workstealActiveThreads (flipShape loopshr loopsh) accTypeR (body loopshr toOp fromOp, -- the LoopWork StateT $ \op -> second toOp <$> runStateT (foo (liftInt 0) []) (fromOp op)) -- the action to run after the outer loop -- acc'' <- flip execStateT acc' $ foo (liftInt 0) [] @@ -125,47 +110,6 @@ codegen name env (Clustered c b) args = outputs <- evalOps @NativeOp i c newInputs args gamma writeOutputs @_ @_ @NativeOp i args outputs gamma - body' :: Accumulated -> Loopsizes -> CodeGen Native () - body' initialAcc partialLoopSize = - case partialLoopSize of -- used to combine with loopSize here, but I think we can do everything in the static analysis? - LS shr' sh' -> - let go :: ShapeR sh -> Operands sh -> (Int, Operands Int, [Operands Int]) -> StateT Accumulated (CodeGen Native) () - go ShapeRz OP_Unit _ = pure () - go (ShapeRsnoc shr) (OP_Pair sh sz) ix = iter' sz ix $ \i@(d,lin,is) -> do - recurLin <- lift $ mul numType lin (firstOrZero shr sh) - go shr sh (d+1, recurLin, is) - let ba = makeBackendArg @NativeOp args gamma c b - newInputs <- readInputs @_ @_ @NativeOp i args ba gamma - outputs <- evalOps @NativeOp i c newInputs args gamma - writeOutputs @_ @_ @NativeOp i args outputs gamma - in case (shr', flipShape shr' sh') of - -- (ShapeRz,_) -> error "tiny cluster" - -- (ShapeRsnoc shr', OP_Pair sh' sz) -> - (shr', sh') -> - flip evalStateT initialAcc $ do - -- we add one more layer here, for writing scalars -- e.g. the output of a fold over a vector - go shr' sh' (1, constant typerInt 0, []) - let ba = makeBackendArg @NativeOp args gamma c b - let i = (0, constant typerInt 0 ,[]) - newInputs <- readInputs @_ @_ @NativeOp i args ba gamma - outputs <- evalOps @NativeOp i c newInputs args gamma - writeOutputs @_ @_ @NativeOp i args outputs gamma - - iter' :: Operands Int - -> (Int, Operands Int, [Operands Int]) - -> ((Int, Operands Int, [Operands Int]) -> StateT Accumulated (CodeGen Native) ()) -> StateT Accumulated (CodeGen Native) () - iter' size (d, linearI, outerI) body = StateT $ \accMap -> - operandsMapToPairs accMap $ \(accR, toR, fromR) -> - ((),) . fromR <$> iter - (TupRpair typerInt typerInt) - accR - (OP_Pair (constant typerInt 0) linearI) - (toR accMap) - (\(OP_Pair i _) -> lt singleType i size) - (\(OP_Pair i l) -> OP_Pair <$> add numType (constant typerInt 1) i <*> add numType (constant typerInt 1) l) - (\(OP_Pair i l) -> fmap (toR . snd) . runStateT (body (d, l, i:outerI)) . fromR) - - -- We use some unsafe coerces in the context of the accumulators. @@ -223,7 +167,7 @@ operandsMapToPairs acc k -- get the largest ShapeR, and the corresponding shape combine :: Loopsizes -> Loopsizes -> Loopsizes combine x@(LS a _) y@(LS b _) = if rank a > rank b then x else y - + type Accumulated = M.Map Label (Exists Operands, Exists TypeR) @@ -237,7 +181,7 @@ instance EvalOp NativeOp where type EnvF NativeOp = GroundOperand embed (GroundOperandParam x) = Compose $ Just $ ir' x - embed (GroundOperandBuffer x) = error "does this ever happen?" + embed (GroundOperandBuffer _) = error "does this ever happen?" unit = Compose $ Just OP_Unit @@ -257,10 +201,10 @@ instance EvalOp NativeOp where writeOutput _ _ _ _ _ = error "not single" readInput :: forall e env sh. ScalarType e -> GroundVars env sh -> GroundVars env (Buffers e) -> Gamma env -> BackendClusterArg2 NativeOp env (In sh e) -> (Int, Operands Int, [Operands Int]) -> StateT Accumulated (CodeGen Native) (Compose Maybe Operands e) readInput _ _ _ _ _ (d,_,is) | d /= length is = error "fail" - readInput tp sh (TupRsingle buf) gamma (BCAN2 Nothing d') (d,i, _) + readInput tp _ (TupRsingle buf) gamma (BCAN2 Nothing d') (d,i, _) | d /= d' = pure CN | otherwise = lift $ CJ . ir tp <$> readBuffer tp TypeInt (aprjBuffer (unsafeCoerce buf) gamma) (op TypeInt i) - readInput tp sh (TupRsingle buf) gamma (BCAN2 (Just (BP shr1 (shr2 :: ShapeR sh2) f ls)) _) (d,_,ix) + readInput tp sh (TupRsingle buf) gamma (BCAN2 (Just (BP shr1 (shr2 :: ShapeR sh2) f _ls)) _) (d,_,ix) | Just Refl <- varsContainsThisShape sh shr2 , shr1 `isAtDepth'` d = lift $ CJ . ir tp <$> do @@ -270,7 +214,7 @@ instance EvalOp NativeOp where i <- intOfIndex shr2 sh' sh2 readBuffer tp TypeInt (aprjBuffer (unsafeCoerce buf) gamma) (op TypeInt i) | otherwise = pure CN - readInput tp _ (TupRsingle buf) gamma a (_,i,_) = error "here" + readInput _ _ (TupRsingle _) _ _ (_,_,_) = error "here" -- assuming no bp, and I'll just make a read at every depth? -- lift $ CJ . ir tp <$> readBuffer tp TypeInt (aprjBuffer (unsafeCoerce buf) gamma) (op TypeInt i) -- second attempt, the above segfaults: never read instead @@ -292,17 +236,17 @@ instance EvalOp NativeOp where -- evalOp _ _ _ _ _ = error "todo: add depth checks to all matches" evalOp (d,_,_) _ NMap gamma (Push (Push (Push Env.Empty (BAE _ _)) (BAE (Value' x' (Shape' shr sh)) (BCAN2 _ d'))) (BAE f _)) = lift $ Push Env.Empty . FromArg . flip Value' (Shape' shr sh) <$> case x' of - CJ x | d == d' -> CJ <$> (traceIfDebugging ("map" <> show d') $ app1 (llvmOfFun1 @Native f gamma) x) + CJ x | d == d' -> CJ <$> traceIfDebugging ("map" <> show d') (app1 (llvmOfFun1 @Native f gamma) x) _ -> pure CN evalOp _ _ NBackpermute _ (Push (Push (Push Env.Empty (BAE (Shape' shr sh) _)) (BAE (Value' x _) _)) _) = lift $ pure $ Push Env.Empty $ FromArg $ Value' x $ Shape' shr sh - evalOp (d',_,is) _ NGenerate gamma (Push (Push Env.Empty (BAE (Shape' shr (CJ sh)) _)) (BAE f (BCAN2 Nothing d))) + evalOp (d',_,is) _ NGenerate gamma (Push (Push Env.Empty (BAE (Shape' shr (CJ sh)) _)) (BAE f (BCAN2 Nothing _d))) | shr `isAtDepth'` d' = traceIfDebugging ("generate" <> show d') $ lift $ Push Env.Empty . FromArg . flip Value' (Shape' shr (CJ sh)) . CJ <$> app1 (llvmOfFun1 @Native f gamma) (multidim shr is) -- | d' == d = error $ "how come we didn't hit the case above?" <> show (d', d, rank shr) | otherwise = pure $ Push Env.Empty $ FromArg $ Value' CN (Shape' shr (CJ sh)) - evalOp (d',_,is) _ NGenerate gamma (Push (Push Env.Empty (BAE (Shape' shr sh) _)) (BAE f (BCAN2 (Just (BP shrO shrI idxTransform ls)) d))) + evalOp (d',_,is) _ NGenerate gamma (Push (Push Env.Empty (BAE (Shape' shr sh) _)) (BAE f (BCAN2 (Just (BP shrO shrI idxTransform _ls)) _d))) | not $ shrO `isAtDepth'` d' - = pure $ Push Env.Empty $ FromArg $ flip Value' (Shape' shr sh) CN + = pure $ Push Env.Empty $ FromArg $ Value' CN (Shape' shr sh) | Just Refl <- matchShapeR shrI shr = traceIfDebugging ("generate" <> show d') $ lift $ Push Env.Empty . FromArg . flip Value' (Shape' shr sh) . CJ <$> app1 (llvmOfFun1 @Native (compose f idxTransform) gamma) (multidim shrO is) | otherwise = error "bp shapeR doesn't match generate's output" @@ -340,10 +284,10 @@ instance EvalOp NativeOp where z <- ifThenElse (ty, eq singleType inner (constant typerInt 0)) (pure x) (app2 f accX x) -- note: need to apply the accumulator as the _left_ argument to the function pure (Push Env.Empty $ FromArg (Value' (CJ z) sh), M.adjust (const (Exists z, eTy)) l acc) | otherwise = pure $ Push Env.Empty $ FromArg (Value' CN sh) - evalOp _ _ NScanl1 gamma (Push (Push _ (BAE (Value' x' sh) (BCAN2 (Just (BP{})) d))) (BAE f' _)) = error "backpermuted scan" - evalOp i l NFold1 gamma args = error "todo: fold1" + evalOp _ _ NScanl1 _ (Push (Push _ (BAE _ (BCAN2 (Just BP{}) _))) (BAE _ _)) = error "backpermuted scan" + evalOp _ _ NFold1 _ _ = error "todo: fold1" -- we can ignore the index permutation for folds here - evalOp (d',_,ixs) l NFold2 gamma (Push (Push _ (BAE (Value' x' sh@(Shape' (ShapeRsnoc shr') (CJ (OP_Pair sh' _)))) (BCAN2 _ d))) (BAE f' _)) + evalOp (d',_,ixs) l NFold2 gamma (Push (Push _ (BAE (Value' x' (Shape' (ShapeRsnoc shr') (CJ (OP_Pair sh _)))) (BCAN2 _ d))) (BAE f' _)) | f <- llvmOfFun2 @Native f' gamma , Lam (lhsToTupR -> ty :: TypeR e) _ <- f' , CJ x <- x' @@ -352,14 +296,13 @@ instance EvalOp NativeOp where = traceIfDebugging ("fold2work" <> show d') $ StateT $ \acc -> do let (Exists (unsafeCoerce @(Operands _) @(Operands e) -> accX), eTy) = acc M.! l z <- ifThenElse (ty, eq singleType inner (constant typerInt 0)) (pure x) (app2 f accX x) -- note: need to apply the accumulator as the _left_ argument to the function - pure (Push Env.Empty $ FromArg (Value' (CJ z) (Shape' shr' (CJ sh'))), M.adjust (const (Exists z, eTy)) l acc) - | f <- llvmOfFun2 @Native f' gamma - , Lam (lhsToTupR -> ty :: TypeR e) _ <- f' + pure (Push Env.Empty $ FromArg (Value' (CJ z) (Shape' shr' (CJ sh))), M.adjust (const (Exists z, eTy)) l acc) + | Lam (lhsToTupR -> _ty :: TypeR e) _ <- f' , d == d'+1 -- the fold was in the iteration above; we grab the result from the accumulator now = traceIfDebugging ("fold2done" <> show d') $ StateT $ \acc -> do let (Exists (unsafeCoerce @(Operands _) @(Operands e) -> x), _) = acc M.! l - pure (Push Env.Empty $ FromArg (Value' (CJ x) (Shape' shr' (CJ sh'))), acc) - | otherwise = pure $ Push Env.Empty $ FromArg (Value' CN (Shape' shr' (CJ sh'))) + pure (Push Env.Empty $ FromArg (Value' (CJ x) (Shape' shr' (CJ sh))), acc) + | otherwise = pure $ Push Env.Empty $ FromArg (Value' CN (Shape' shr' (CJ sh))) evalOp _ _ _ _ _ = error "unmatched pattern?" @@ -396,7 +339,7 @@ data Loopsizes where LS :: ShapeR sh -> Operands sh -> Loopsizes instance Show Loopsizes where - show (LS shr op) = "Loop of rank " <> show (rank shr) + show (LS shr _) = "Loop of rank " <> show (rank shr) merge :: Loopsizes -> GroundVars env sh -> Gamma env -> Loopsizes merge ls v gamma = combine ls $ LS (gvarsToShapeR v) (aprjParameters (unsafeToExpVars v) gamma) @@ -423,7 +366,7 @@ merge ls v gamma = combine ls $ LS (gvarsToShapeR v) (aprjParameters (unsafeToEx -- -- False -- _ -> error "huh" -- where flipShape x y = y - + -- error "todo: take the known indices from the smaller True one, and the rest from the larger False one, call the result False" instance EvalOp (JustAccumulator NativeOp) where @@ -440,16 +383,16 @@ instance EvalOp (JustAccumulator NativeOp) where subtup SubTupRskip _ = Both TupRunit OP_Unit subtup SubTupRkeep x = x subtup (SubTupRpair a b) (Both (TupRpair x y) (OP_Pair x' y')) = case (subtup @(JustAccumulator NativeOp) a (Both x x'), subtup @(JustAccumulator NativeOp) b (Both y y')) of - (Both l l', Both r r') -> Both (TupRpair l r) (OP_Pair l' r') + (Both l l', Both r r') -> Both (TupRpair l r) (OP_Pair l' r') subtup _ _ = error "subtup-pair with non-pair TypeR" - readInput ty sh _ gamma (BCA2JA IsUnit) _ = pure $ Both TupRunit OP_Unit - readInput ty sh _ gamma (BCA2JA (BCAN2 Nothing d)) _ = StateT $ \(acc,ls) -> pure (Both (TupRsingle ty) (zeroes $ TupRsingle ty), (acc, merge ls sh gamma)) - readInput ty sh _ gamma (BCA2JA (BCAN2 (Just (BP _ _ _ ls')) d)) _ = StateT $ \(acc,ls) -> pure (Both (TupRsingle ty) (zeroes $ TupRsingle ty), (acc, merge ls ls' gamma)) + readInput _ _ _ _ (BCA2JA IsUnit) _ = pure $ Both TupRunit OP_Unit + readInput ty sh _ gamma (BCA2JA (BCAN2 Nothing _)) _ = StateT $ \(acc,ls) -> pure (Both (TupRsingle ty) (zeroes $ TupRsingle ty), (acc, merge ls sh gamma)) + readInput ty _ _ gamma (BCA2JA (BCAN2 (Just (BP _ _ _ ls')) _)) _ = StateT $ \(acc,ls) -> pure (Both (TupRsingle ty) (zeroes $ TupRsingle ty), (acc, merge ls ls' gamma)) - writeOutput ty sh buf gamma ix x = StateT $ \(acc,ls) -> pure ((), (acc, merge ls sh gamma)) + writeOutput _ sh _ gamma _ _ = StateT $ \(acc,ls) -> pure ((), (acc, merge ls sh gamma)) - evalOp () l (JA NScanl1) _ (Push (Push _ (BAE (Value' (Both ty x) sh) _)) (BAE f _)) + evalOp () l (JA NScanl1) _ (Push (Push _ (BAE (Value' (Both ty x) sh) _)) (BAE _ _)) = StateT $ \(acc,ls) -> do let thing = zeroes ty pure (Push Env.Empty $ FromArg (Value' (Both ty x) sh), (M.insert l (Exists thing, Exists ty) acc, ls)) @@ -466,7 +409,7 @@ instance EvalOp (JustAccumulator NativeOp) where = lift $ pure $ Push Env.Empty $ FromArg $ Value' (Both (getOutputType f) (zeroes (getOutputType f))) (Shape' shr sh) evalOp _ _ (JA NPermute) _ (Push (Push (Push (Push (Push Env.Empty (BAE (Value' _ (Shape' shr (Both _ sh))) _)) _) _) _) _) = StateT $ \(acc,ls) -> pure (Env.Empty, (acc, combine ls $ LS shr sh)) - + getOutputType :: Fun env (i -> o) -> TypeR o getOutputType (Lam _ (Body e)) = expType e @@ -476,12 +419,13 @@ instance MakesILP op => MakesILP (JustAccumulator op) where type BackendVar (JustAccumulator op) = BackendVar op type BackendArg (JustAccumulator op) = BackendArg op newtype BackendClusterArg (JustAccumulator op) a = BCAJA (BackendClusterArg op a) - mkGraph (JA op) = undefined -- do not want to run the ILP solver again! + mkGraph (JA _) = undefined -- do not want to run the ILP solver again! finalize = undefined labelLabelledArg = undefined getClusterArg = undefined encodeBackendClusterArg = undefined combineBackendClusterArg = undefined + defaultBA = undefined -- probably just never running anything here -- this is really just here because MakesILP is a superclass @@ -500,12 +444,15 @@ instance (StaticClusterAnalysis op, EnvF (JustAccumulator op) ~ EnvF op) => Stat shToValue x = coerce @(BackendClusterArg2 op _ _) @(BackendClusterArg2 (JustAccumulator op) _ _) $ shToValue $ coerce x varToValue x = coerce @(BackendClusterArg2 op _ _) @(BackendClusterArg2 (JustAccumulator op) _ _) $ varToValue $ coerce x varToSh x = coerce @(BackendClusterArg2 op _ _) @(BackendClusterArg2 (JustAccumulator op) _ _) $ varToSh $ coerce x + outToVar x = coerce @(BackendClusterArg2 op _ _) @(BackendClusterArg2 (JustAccumulator op) _ _) $ outToVar $ coerce x shToVar x = coerce @(BackendClusterArg2 op _ _) @(BackendClusterArg2 (JustAccumulator op) _ _) $ shToVar $ coerce x shrinkOrGrow a b x = coerce @(BackendClusterArg2 op _ _) @(BackendClusterArg2 (JustAccumulator op) _ _) $ shrinkOrGrow a b $ coerce x addTup x = coerce @(BackendClusterArg2 op _ _) @(BackendClusterArg2 (JustAccumulator op) _ _) $ addTup $ coerce x unitToVar x = coerce @(BackendClusterArg2 op _ _) @(BackendClusterArg2 (JustAccumulator op) _ _) $ unitToVar $ coerce x varToUnit x = coerce @(BackendClusterArg2 op _ _) @(BackendClusterArg2 (JustAccumulator op) _ _) $ varToUnit $ coerce x pairinfo a x y = coerce @(BackendClusterArg2 op _ _) @(BackendClusterArg2 (JustAccumulator op) _ _) $ pairinfo a (coerce x) (coerce y) + bcaid x = coerce @(BackendClusterArg op _) @(BackendClusterArg (JustAccumulator op) _) $ bcaid $ coerce x + deriving instance (Eq (BackendClusterArg2 op x y)) => Eq (BackendClusterArg2 (JustAccumulator op) x y) deriving instance (Show (BackendClusterArg2 op x y)) => Show (BackendClusterArg2 (JustAccumulator op) x y) diff --git a/accelerate-llvm-native/src/Data/Array/Accelerate/LLVM/Native/Operation.hs b/accelerate-llvm-native/src/Data/Array/Accelerate/LLVM/Native/Operation.hs index f5847a31f..559cb196a 100644 --- a/accelerate-llvm-native/src/Data/Array/Accelerate/LLVM/Native/Operation.hs +++ b/accelerate-llvm-native/src/Data/Array/Accelerate/LLVM/Native/Operation.hs @@ -37,10 +37,10 @@ import Data.Array.Accelerate.Eval import qualified Data.Set as Set -import Data.Array.Accelerate.AST.Environment (weakenId, weakenEmpty, (.>), weakenSucc' ) +import Data.Array.Accelerate.AST.Environment (weakenId, weakenEmpty, weakenSucc' ) import Data.Array.Accelerate.Representation.Array (ArrayR(..)) import Data.Array.Accelerate.Trafo.Var (DeclareVars(..), declareVars) -import Data.Array.Accelerate.Representation.Ground (buffersR, typeRtoGroundsR) +import Data.Array.Accelerate.Representation.Ground (buffersR) import Data.Array.Accelerate.AST.LeftHandSide import Data.Array.Accelerate.Trafo.Operation.Substitution (aletUnique, alet, weaken, LHS (..), mkLHS) import Data.Array.Accelerate.Representation.Shape (ShapeR (..), shapeType, rank) @@ -57,11 +57,9 @@ import qualified Data.Map as M import Data.Array.Accelerate.Trafo.Exp.Substitution import Data.Array.Accelerate.Trafo.Desugar (desugarAlloc) -import qualified Debug.Trace -import GHC.Stack import Data.Array.Accelerate.AST.Idx (Idx(..)) import Data.Array.Accelerate.Pretty.Operation (prettyFun) -import Data.Array.Accelerate.Pretty.Exp (PrettyEnv(..), Val (Push)) +import Data.Array.Accelerate.Pretty.Exp (Val (Push)) import Prettyprinter (pretty) import Unsafe.Coerce (unsafeCoerce)