diff --git a/accelerate-llvm-native/src/Data/Array/Accelerate/LLVM/Native/CodeGen.hs b/accelerate-llvm-native/src/Data/Array/Accelerate/LLVM/Native/CodeGen.hs index 244e8fe93..c5139cf36 100644 --- a/accelerate-llvm-native/src/Data/Array/Accelerate/LLVM/Native/CodeGen.hs +++ b/accelerate-llvm-native/src/Data/Array/Accelerate/LLVM/Native/CodeGen.hs @@ -480,6 +480,7 @@ instance (StaticClusterAnalysis op, EnvF (JustAccumulator op) ~ EnvF op) => Stat addTup x = coerce @(BackendClusterArg2 op _ _) @(BackendClusterArg2 (JustAccumulator op) _ _) $ addTup $ coerce x unitToVar x = coerce @(BackendClusterArg2 op _ _) @(BackendClusterArg2 (JustAccumulator op) _ _) $ unitToVar $ coerce x varToUnit x = coerce @(BackendClusterArg2 op _ _) @(BackendClusterArg2 (JustAccumulator op) _ _) $ varToUnit $ coerce x + pairinfo x y = coerce @(BackendClusterArg2 op _ _) @(BackendClusterArg2 (JustAccumulator op) _ _) $ pairinfo (coerce x) (coerce y) deriving instance (Eq (BackendClusterArg2 op x y)) => Eq (BackendClusterArg2 (JustAccumulator op) x y) deriving instance (Show (BackendClusterArg2 op x y)) => Show (BackendClusterArg2 (JustAccumulator op) x y) diff --git a/accelerate-llvm-native/src/Data/Array/Accelerate/LLVM/Native/Operation.hs b/accelerate-llvm-native/src/Data/Array/Accelerate/LLVM/Native/Operation.hs index a2ab28580..e5eb01981 100644 --- a/accelerate-llvm-native/src/Data/Array/Accelerate/LLVM/Native/Operation.hs +++ b/accelerate-llvm-native/src/Data/Array/Accelerate/LLVM/Native/Operation.hs @@ -59,6 +59,10 @@ import Data.Array.Accelerate.Trafo.Desugar (desugarAlloc) import qualified Debug.Trace import GHC.Stack import Data.Array.Accelerate.AST.Idx (Idx(..)) +import Data.Array.Accelerate.Pretty.Operation (prettyFun) +import Data.Array.Accelerate.Pretty.Exp (PrettyEnv(..), Val (Push)) +import Prettyprinter (pretty) +import Unsafe.Coerce (unsafeCoerce) data NativeOp t where NMap :: NativeOp (Fun' (s -> t) -> In sh s -> Out sh t -> ()) @@ -316,11 +320,16 @@ data IndexPermutation env where type IterationDepth = Int instance Show (BackendClusterArg2 NativeOp env arg) where show (BCAN2 i d) = "{ depth = " <> show d <> ", perm = " <> show i <> " }" + show IsUnit = "()" instance Show (IndexPermutation env) where - show (BP sh1 sh2 _ _) = show (rank sh1) <> "->" <> show (rank sh2) + show (BP sh1 sh2 f _) = show (rank sh1) <> "->" <> show (rank sh2) <> ": " <> show (prettyFun (infenv 0) f) + where + infenv i = unsafeCoerce $ infenv (i+1) `Push` (pretty $ "x"<>show i) instance StaticClusterAnalysis NativeOp where data BackendClusterArg2 NativeOp env arg where BCAN2 :: Maybe (IndexPermutation env) -> IterationDepth -> BackendClusterArg2 NativeOp env arg + IsUnit ::BackendClusterArg2 NativeOp env (m sh ()) -- units don't get backpermuted because they don't exist + def (ArgArray _ (ArrayR _ TupRunit) _ _) _ _ = IsUnit def _ _ (BCAN i) = BCAN2 Nothing i unitToVar = bcan2id varToUnit = bcan2id @@ -344,12 +353,17 @@ instance StaticClusterAnalysis NativeOp where onOp NBackpermute (BCAN2 Nothing d :>: ArgsNil) (ArgFun f :>: ArgArray In (ArrayR shrI _) _ _ :>: ArgArray Out (ArrayR shrO _) sh _ :>: ArgsNil) _ = BCAN2 Nothing d :>: BCAN2 (Just (BP shrO shrI f sh)) d :>: BCAN2 Nothing d :>: ArgsNil onOp NGenerate (bp :>: ArgsNil) (_:>:ArgArray Out (ArrayR shR _) _ _ :>:ArgsNil) _ = - bcan2id bp :>: bp :>: ArgsNil -- storing the bp in the function argument. Probably not required, could just take it from the array one during codegen + bcan2id bp :>: bp :>: ArgsNil -- store the bp in the function, because there is no input array onOp NPermute ArgsNil (_:>:_:>:_:>:_:>:ArgArray In (ArrayR shR _) _ _ :>:ArgsNil) _ = BCAN2 Nothing 0 :>: BCAN2 Nothing 0 :>: BCAN2 Nothing 0 :>: BCAN2 Nothing 0 :>: BCAN2 Nothing (rank shR) :>: ArgsNil onOp NFold2 (bp :>: ArgsNil) (_ :>: ArgArray In _ fs _ :>: _ :>: ArgsNil) _ = BCAN2 Nothing 0 :>: fold2bp bp (case fs of TupRpair _ x -> x) :>: bp :>: ArgsNil onOp NFold1 (bp :>: ArgsNil) _ _ = BCAN2 Nothing 0 :>: fold1bp bp :>: bp :>: ArgsNil onOp NScanl1 (bp :>: ArgsNil) _ _ = BCAN2 Nothing 0 :>: bcan2id bp :>: bp :>: ArgsNil + pairinfo IsUnit x = shrinkOrGrow x + pairinfo x IsUnit = shrinkOrGrow x + pairinfo a b = if shrinkOrGrow a == b then shrinkOrGrow a else error $ "pairing unequal: " <> show a <> ", " <> show b + + bcan2id :: BackendClusterArg2 NativeOp env arg -> BackendClusterArg2 NativeOp env arg' bcan2id (BCAN2 Nothing i) = BCAN2 Nothing i @@ -372,10 +386,13 @@ fold2bp (BCAN2 (Just (BP shr1 shr2 g sh)) i) foldsize = flip BCAN2 (i+1) $ Just (TupRpair sh foldsize) instance Eq (BackendClusterArg2 NativeOp env arg) where + IsUnit == IsUnit = True x@(BCAN2 p i) == y@(BCAN2 p' i') = f $ p == p' && i == i' where f True = True f False = False + _ == _ = False + instance Eq (IndexPermutation env) where (BP shr1 shr2 f _) == (BP shr1' shr2' f' _) | Just Refl <- matchShapeR shr1 shr1' diff --git a/accelerate-llvm-native/test/nofib/Main.hs b/accelerate-llvm-native/test/nofib/Main.hs index 329977114..1b2f95e99 100644 --- a/accelerate-llvm-native/test/nofib/Main.hs +++ b/accelerate-llvm-native/test/nofib/Main.hs @@ -32,14 +32,19 @@ import Data.Array.Accelerate.Unsafe main :: IO () main = do let xs = fromList (Z :. 10) [1 :: Int ..] - let ys = use xs - - - let f = T2 (map (+1) ys) (map (*2) $ reverse ys) - -- let f = --map (\(T2 a b) -> a + b) $ - -- zip ys $ reverse ys - putStrLn $ test @UniformScheduleFun @NativeKernel f - print $ run @Native f + let ys = map (\x -> T2 x x) $ + use xs + + + -- let f = T2 (map (+1) ys) (map (*2) $ reverse ys) + -- let f = sum $ map (\(T2 a b) -> a + b) $ + -- zip (reverse $ map (+1) (reverse ys)) $ reverse ys + let Z_ ::. n = shape ys + let f'' = backpermute (Z_ ::. 5 ::. 2) (\(I2 x y) -> I1 (x*y)) ys + let f' = replicate (Z_ ::. All_ ::. n) ys + let f = zip (reverse ys) ys + putStrLn $ test @UniformScheduleFun @NativeKernel $ backpermute (Z_ ::. 5) (\x->x) (reverse ys) + -- print $ run @Native $ f -- putStrLn "generate:" -- let f = generate (I1 10) (\(I1 x0) -> 10 :: Exp Int)