From 51293533ff8b3ec729ffd1c02595caccf3bd9354 Mon Sep 17 00:00:00 2001 From: David van Balen Date: Wed, 29 May 2024 12:07:21 +0200 Subject: [PATCH] current --- .../Array/Accelerate/LLVM/Native/CodeGen.hs | 22 +++++----- .../Accelerate/LLVM/Native/CodeGen/Loop.hs | 3 +- .../LLVM/Native/Execute/Scheduler.hs | 4 +- accelerate-llvm-native/test/nofib/Main.hs | 42 +++++++++++++------ 4 files changed, 44 insertions(+), 27 deletions(-) diff --git a/accelerate-llvm-native/src/Data/Array/Accelerate/LLVM/Native/CodeGen.hs b/accelerate-llvm-native/src/Data/Array/Accelerate/LLVM/Native/CodeGen.hs index 9f3611c25..421d33f9b 100644 --- a/accelerate-llvm-native/src/Data/Array/Accelerate/LLVM/Native/CodeGen.hs +++ b/accelerate-llvm-native/src/Data/Array/Accelerate/LLVM/Native/CodeGen.hs @@ -88,7 +88,7 @@ import Data.Array.Accelerate.Backend (SLVOperation(..)) import Data.Array.Accelerate.LLVM.CodeGen.IR - +traceIfDebugging str a = a --Debug.Trace.trace str a codegen :: ShortByteString -> Env AccessGroundR env @@ -105,7 +105,7 @@ codegen name env (Clustered c b) args = (acc, loopsize) <- execStateT (evalCluster (toOnlyAcc c) b' args gamma ()) (mempty, LS ShapeRz OP_Unit) -- body acc loopsize' acc' <- operandsMapToPairs acc $ \(accTypeR, toOp, fromOp) -> fmap fromOp $ flip execStateT (toOp acc) $ case loopsize of - LS loopshr loopsh -> + LS loopshr loopsh -> -- Debug.Trace.traceShow (rank loopshr) $ workstealChunked loopshr workstealIndex workstealActiveThreads (flipShape loopshr loopsh) accTypeR (body loopshr toOp fromOp, -- the LoopWork StateT $ \op -> second toOp <$> runStateT (foo (liftInt 0) []) (fromOp op)) -- the action to run after the outer loop @@ -292,19 +292,19 @@ instance EvalOp NativeOp where -- evalOp _ _ _ _ _ = error "todo: add depth checks to all matches" evalOp (d,_,_) _ NMap gamma (Push (Push (Push Env.Empty (BAE _ _)) (BAE (Value' x' (Shape' shr sh)) (BCAN2 _ d'))) (BAE f _)) = lift $ Push Env.Empty . FromArg . flip Value' (Shape' shr sh) <$> case x' of - CJ x | d == d' -> CJ <$> app1 (llvmOfFun1 @Native f gamma) x + CJ x | d == d' -> CJ <$> (traceIfDebugging ("map" <> show d') $ app1 (llvmOfFun1 @Native f gamma) x) _ -> pure CN evalOp _ _ NBackpermute _ (Push (Push (Push Env.Empty (BAE (Shape' shr sh) _)) (BAE (Value' x _) _)) _) = lift $ pure $ Push Env.Empty $ FromArg $ Value' x $ Shape' shr sh evalOp (d',_,is) _ NGenerate gamma (Push (Push Env.Empty (BAE (Shape' shr (CJ sh)) _)) (BAE f (BCAN2 Nothing d))) - | shr `isAtDepth'` d' = lift $ Push Env.Empty . FromArg . flip Value' (Shape' shr (CJ sh)) . CJ <$> app1 (llvmOfFun1 @Native f gamma) (multidim shr is) - | d' == d = error $ "how come we didn't hit the case above?" <> show (d', d, rank shr) + | shr `isAtDepth'` d' = traceIfDebugging ("generate" <> show d') $ lift $ Push Env.Empty . FromArg . flip Value' (Shape' shr (CJ sh)) . CJ <$> app1 (llvmOfFun1 @Native f gamma) (multidim shr is) + -- | d' == d = error $ "how come we didn't hit the case above?" <> show (d', d, rank shr) | otherwise = pure $ Push Env.Empty $ FromArg $ Value' CN (Shape' shr (CJ sh)) evalOp (d',_,is) _ NGenerate gamma (Push (Push Env.Empty (BAE (Shape' shr sh) _)) (BAE f (BCAN2 (Just (BP shrO shrI idxTransform ls)) d))) | not $ shrO `isAtDepth'` d' = pure $ Push Env.Empty $ FromArg $ flip Value' (Shape' shr sh) CN | Just Refl <- matchShapeR shrI shr - = lift $ Push Env.Empty . FromArg . flip Value' (Shape' shr sh) . CJ <$> app1 (llvmOfFun1 @Native (compose f idxTransform) gamma) (multidim shrO is) + = traceIfDebugging ("generate" <> show d') $ lift $ Push Env.Empty . FromArg . flip Value' (Shape' shr sh) . CJ <$> app1 (llvmOfFun1 @Native (compose f idxTransform) gamma) (multidim shrO is) | otherwise = error "bp shapeR doesn't match generate's output" -- For Permute, we ignore all the BP info here and simply assume that there is none evalOp (d',_,is) _ NPermute gamma (Push (Push (Push (Push (Push Env.Empty @@ -315,7 +315,7 @@ instance EvalOp NativeOp where (BAE (flip (llvmOfFun2 @Native) gamma -> c) _)) -- combination function | CJ x <- x' , shrx `isAtDepth'` d' - = lift $ do + = traceIfDebugging ("permute" <> show d') $ lift $ do ix' <- app1 f (multidim shrx is) -- project element onto the destination array and (atomically) update when (isJust ix') $ do @@ -335,7 +335,7 @@ instance EvalOp NativeOp where , CJ x <- x' , d == d' , inner:_ <- ixs - = StateT $ \acc -> do + = traceIfDebugging ("scan" <> show d') $ StateT $ \acc -> do let (Exists (unsafeCoerce @(Operands _) @(Operands e) -> accX), eTy) = acc M.! l z <- ifThenElse (ty, eq singleType inner (constant typerInt 0)) (pure x) (app2 f accX x) -- note: need to apply the accumulator as the _left_ argument to the function pure (Push Env.Empty $ FromArg (Value' (CJ z) sh), M.adjust (const (Exists z, eTy)) l acc) @@ -349,18 +349,18 @@ instance EvalOp NativeOp where , CJ x <- x' , d == d' , inner:_ <- ixs - = StateT $ \acc -> do + = traceIfDebugging ("fold2work" <> show d') $ StateT $ \acc -> do let (Exists (unsafeCoerce @(Operands _) @(Operands e) -> accX), eTy) = acc M.! l z <- ifThenElse (ty, eq singleType inner (constant typerInt 0)) (pure x) (app2 f accX x) -- note: need to apply the accumulator as the _left_ argument to the function pure (Push Env.Empty $ FromArg (Value' (CJ z) (Shape' shr' (CJ sh'))), M.adjust (const (Exists z, eTy)) l acc) | f <- llvmOfFun2 @Native f' gamma , Lam (lhsToTupR -> ty :: TypeR e) _ <- f' , d == d'+1 -- the fold was in the iteration above; we grab the result from the accumulator now - = StateT $ \acc -> do + = traceIfDebugging ("fold2done" <> show d') $ StateT $ \acc -> do let (Exists (unsafeCoerce @(Operands _) @(Operands e) -> x), _) = acc M.! l pure (Push Env.Empty $ FromArg (Value' (CJ x) (Shape' shr' (CJ sh'))), acc) | otherwise = pure $ Push Env.Empty $ FromArg (Value' CN (Shape' shr' (CJ sh'))) - -- evalOp _ _ _ _ _ = error "unmatched pattern?" + evalOp _ _ _ _ _ = error "unmatched pattern?" instance TupRmonoid Operands where diff --git a/accelerate-llvm-native/src/Data/Array/Accelerate/LLVM/Native/CodeGen/Loop.hs b/accelerate-llvm-native/src/Data/Array/Accelerate/LLVM/Native/CodeGen/Loop.hs index a5804cd65..685ccf3f8 100644 --- a/accelerate-llvm-native/src/Data/Array/Accelerate/LLVM/Native/CodeGen/Loop.hs +++ b/accelerate-llvm-native/src/Data/Array/Accelerate/LLVM/Native/CodeGen/Loop.hs @@ -90,8 +90,9 @@ loopWorkFromTo shr start end extent tys (loopwork,finish) = do loopWorkFromTo' :: ShapeR sh -> Operands sh -> Operands sh -> Operands sh -> Operands Int -> [Operands Int] -> TypeR s -> LoopWork sh (StateT (Operands s) (CodeGen Native)) -> StateT (Operands s) (CodeGen Native) () loopWorkFromTo' ShapeRz OP_Unit OP_Unit OP_Unit _ _ _ LoopWorkZ = pure () -loopWorkFromTo' (ShapeRsnoc shr) (OP_Pair start' start) (OP_Pair end' end) (OP_Pair extent' _) linixprev ixs tys (LoopWorkSnoc lw foo) = do +loopWorkFromTo' (ShapeRsnoc shr) (OP_Pair start' start) (OP_Pair end' endMaybe) (OP_Pair extent' extent) linixprev ixs tys (LoopWorkSnoc lw foo) = do linix <- lift $ add numType start linixprev + end <- lift $ A.min singleType endMaybe extent StateT $ \s -> ((),) <$> Loop.iter (TupRpair typerInt typerInt) tys diff --git a/accelerate-llvm-native/src/Data/Array/Accelerate/LLVM/Native/Execute/Scheduler.hs b/accelerate-llvm-native/src/Data/Array/Accelerate/LLVM/Native/Execute/Scheduler.hs index 184066a94..e10c53744 100644 --- a/accelerate-llvm-native/src/Data/Array/Accelerate/LLVM/Native/Execute/Scheduler.hs +++ b/accelerate-llvm-native/src/Data/Array/Accelerate/LLVM/Native/Execute/Scheduler.hs @@ -183,9 +183,9 @@ hireWorkers :: IO Workers hireWorkers = do nproc <- getNumProcessors ncaps <- getNumCapabilities - menv <- (readMaybe =<<) <$> lookupEnv "ACCELERATE_LLVM_NATIVE_THREADS" + -- menv <- (readMaybe =<<) <$> lookupEnv "ACCELERATE_LLVM_NATIVE_THREADS" - let nthreads = fromMaybe nproc menv + let nthreads = 1 --fromMaybe nproc menv workers <- hireWorkersOn [0 .. nthreads-1] return workers diff --git a/accelerate-llvm-native/test/nofib/Main.hs b/accelerate-llvm-native/test/nofib/Main.hs index 5dca7e766..a0227ebe8 100644 --- a/accelerate-llvm-native/test/nofib/Main.hs +++ b/accelerate-llvm-native/test/nofib/Main.hs @@ -30,11 +30,27 @@ import Data.Array.Accelerate.Data.Bits import Data.Array.Accelerate.Unsafe import Control.Concurrent -- import Quickhull + +loop :: [a] -> [a] +loop xs = xs Prelude.<> loop xs + main :: IO () main = do - -- let xs = fromList (Z :. 5 :. 10) [1 :: Int ..] - -- let ys = map (+1) $ + + let histogram :: Acc (Vector Int) -> Acc (Vector Int) + histogram xs = + let zeros = fill (constant (Z:.10)) 0 + ones = fill (shape xs) 1 + in + permute (+) zeros (\ix -> Just_ (I1 (xs!ix))) ones + + let xs = fromList (Z :. 50) $ loop $ [1 :: Int .. 9] Prelude.<> [2 .. 8] + + -- let ys = map (\x -> x*x) $ -- use xs + -- let zs = sum $ reshape (Z_ ::. 5 ::. 10) ys + -- let (zs, _) = A.unzip $ awhile (A.any (\(T2 a b) -> a <= 5) ) (map (\(T2 a b) -> T2 b (a+1))) ys + -- let f = map (*2) -- let program = awhile (map (A.>0) . asnd) (\(T2 a b) -> T2 (f a) (map (\x -> x - 1) b)) (T2 ys $ unit $ constant (100000 :: Int)) @@ -45,8 +61,8 @@ main = do -- then -1 -- else 0 :: Exp Double - -- putStrLn $ test @UniformScheduleFun @NativeKernel zs - -- print $ run @Native zs + putStrLn $ test @UniformScheduleFun @NativeKernel histogram + print $ runN @Native histogram xs -- let negatives = [ -- I3 211 154 98, @@ -71,15 +87,15 @@ main = do -- I3 243 172 14, -- I3 54 209 40] - let zs = generate (Z_ ::. constant 15 ::. constant 15 ::. constant 11) $ \(I3 x y z) -> T3 x y z - -- cond (Prelude.foldl1 (||) $ Prelude.map (== idx) negatives) - -- (-1) - -- $ cond (Prelude.foldl1 (||) $ Prelude.map (==idx) positives) - -- 1 - -- 0 :: Exp Double - -- let zs' = zs $ use $ fromList Z [11 :: Int] - putStrLn $ test @UniformScheduleFun @NativeKernel zs - print $ run @Native zs + -- let zs = generate (Z_ ::. constant 15 ::. constant 15 ::. constant 11) $ \(I3 x y z) -> T3 x y z + -- -- cond (Prelude.foldl1 (||) $ Prelude.map (== idx) negatives) + -- -- (-1) + -- -- $ cond (Prelude.foldl1 (||) $ Prelude.map (==idx) positives) + -- -- 1 + -- -- 0 :: Exp Double + -- -- let zs' = zs $ use $ fromList Z [11 :: Int] + -- putStrLn $ test @UniformScheduleFun @NativeKernel zs + -- print $ run @Native zs -- putStrLn "scan:" -- let f =