Skip to content

Commit

Permalink
fixed crashes related to operations on tuples
Browse files Browse the repository at this point in the history
  • Loading branch information
dpvanbalen committed Apr 30, 2024
1 parent d40b81e commit 3ce8345
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 48 deletions.
41 changes: 31 additions & 10 deletions src/Data/Array/Accelerate/AST/Partitioned.hs
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ import Data.Array.Accelerate.Trafo.Operation.LiveVars
import Data.Maybe (fromJust, Maybe (Nothing))
import Data.Array.Accelerate.AST.Var (varsType)
import qualified Debug.Trace
import Data.Array.Accelerate.Error (HasCallStack)
import Data.Array.Accelerate.Analysis.Match (matchShapeR)



Expand All @@ -73,18 +75,33 @@ type PartitionedAfun op = PreOpenAfun (Clustered op)
data Clustered op args = Clustered (Cluster op args) (BackendCluster op args)

data Cluster op args where
Op :: SLVOp op args -> Label -> Cluster op args
SingleOp :: SingleOp op args -> Label -> Cluster op args
Fused :: Fusion largs rargs args
-> Cluster op largs
-> Cluster op rargs
-> Cluster op args

data SingleOp op args where
Single :: op args
-> SOAs args expanded
-> SortedArgs expanded sorted
-> SubArgs sorted live
-> SingleOp op live

-- this pattern synonym translates between the 'new' ast (above) and the 'old' ast (below):
-- the first constructor of `Cluster` used to be called `Op` and contain an `SLVOp`.
-- todo: slowly change all use sites from the old to the new, and eventually retire the old and this pattern synonym.
{-# COMPLETE Op, Fused #-}
pattern Op :: SLVOp op args -> Label -> Cluster op args
pattern Op slv l <- SingleOp (toOld -> slv) l where
Op (SLV (SOp (SOAOp op soas) sortedargs) subargs) l = SingleOp (Single op soas sortedargs subargs) l
toOld (Single op soas sortedargs subargs) = SLV (SOp (SOAOp op soas) sortedargs) subargs

data SLVOp op args where
SLV :: SortedOp op big
-> SubArgs big small
-> SLVOp op small

-- a wrapper around operations which sorts the (term and type level) arguments on global labels, to line the arguments up for Fusion
data SortedOp op args where
SOp :: SOAOp op args
-> SortedArgs args sorted
Expand All @@ -95,7 +112,6 @@ data SortedArgs args sorted where
-> (forall f. PreArgs f sorted -> PreArgs f args)
-> SortedArgs args sorted

-- a wrapper around operations for SOA: each individual buffer from an argument array may fuse differently
data SOAOp op args where
SOAOp :: op args
-> SOAs args expanded
Expand All @@ -110,6 +126,7 @@ data SOA arg appendto result where
SOArgTup :: SOA (f right) args args' -> SOA (f left) args' args'' -> SOA (f (left,right)) args args''



-- These correspond to the inference rules in the paper
-- datatype describing how to combine the arguments of two fused clusters
data Fusion largs rars args where
Expand Down Expand Up @@ -239,9 +256,9 @@ left' k (IntroR f) (_ :>:args) = left' k f args

right :: Fusion largs rargs args -> Args env args -> Args env rargs
right = right' varToIn outToIn
right' :: (forall sh e. f (Var' sh) -> f (In sh e)) -> (forall sh e. f (Out sh e) -> f (In sh e)) -> Fusion largs rargs args -> PreArgs f args -> PreArgs f rargs
right' :: (forall sh e. ArrayR (Array sh e) -> f (Var' sh) -> f (In sh e)) -> (forall sh e. f (Out sh e) -> f (In sh e)) -> Fusion largs rargs args -> PreArgs f args -> PreArgs f rargs
right' vi oi EmptyF ArgsNil = ArgsNil
right' vi oi (Vertical arr f) (arg :>: args) = vi arg :>: right' vi oi f args
right' vi oi (Vertical arr f) (arg :>: args) = vi arr arg :>: right' vi oi f args
right' vi oi (Diagonal f) (arg :>: args) = oi arg :>: right' vi oi f args
right' vi oi (Horizontal f) (arg :>: args) = arg :>: right' vi oi f args
right' vi oi (IntroI1 f) (_ :>: args) = right' vi oi f args
Expand All @@ -251,14 +268,16 @@ right' vi oi (IntroO2 f) (arg :>: args) = arg :>: right' vi oi f args
right' vi oi (IntroL f) (_ :>: args) = right' vi oi f args
right' vi oi (IntroR f) (arg :>: args) = arg :>: right' vi oi f args

varToIn :: Arg env (Var' sh) -> Arg env (In sh e)
varToIn (ArgVar sh) = ArgArray In (ArrayR (varsToShapeR sh) er) (mapTupR (\(Var t ix)->Var (GroundRscalar t) ix) sh) er
varToIn :: ArrayR (Array sh e) -> Arg env (Var' sh) -> Arg env (In sh e)
varToIn (ArrayR shr ty) (ArgVar sh)
| Just Refl <- matchShapeR shr (varsToShapeR sh) = ArgArray In (ArrayR shr ty) (mapTupR (\(Var t ix)->Var (GroundRscalar t) ix) sh) er
| otherwise = error "wrong shape?"
where er = error "accessing fused away array"
outToIn :: Arg env (Out sh e) -> Arg env (In sh e)
outToIn (ArgArray Out x y z) = ArgArray In x y z
inToOut :: Arg env (In sh e) -> Arg env (Out sh e)
inToOut (ArgArray In x y z) = ArgArray Out x y z
varToOut = inToOut . varToIn
varToOut arr = inToOut . varToIn arr

both :: (forall sh e. f (Out sh e) -> f (In sh e) -> f (Var' sh)) -> Fusion largs rargs args -> PreArgs f largs -> PreArgs f rargs -> PreArgs f args
both _ EmptyF ArgsNil ArgsNil = ArgsNil
Expand Down Expand Up @@ -390,10 +409,12 @@ instance TupRmonoid (TupR f) where
unOpLabels' :: LabelledArgsOp op env args -> LabelledArgs env args
unOpLabels' = mapArgs $ \(LOp arg l _) -> L arg l

data Both f g a = Both (f a) (g a)
data Both f g a = Both (f a) (g a) deriving (Show, Eq)
fst' (Both x _) = x
snd' (Both _ y) = y



zipArgs :: PreArgs f a -> PreArgs g a -> PreArgs (Both f g) a
zipArgs ArgsNil ArgsNil = ArgsNil
zipArgs (f:>:fs) (g:>:gs) = Both f g :>: zipArgs fs gs
Expand Down Expand Up @@ -582,7 +603,7 @@ outvar :: Arg env (Out sh e) -> Arg env (Var' sh)
outvar (ArgArray Out (ArrayR shr _) sh _) = ArgVar $ groundToExpVar (shapeType shr) sh

varout :: Arg env (Var' sh) -> Arg env (Out sh e)
varout (ArgVar sh) = ArgArray Out (ArrayR (fromJust $ typeShape $ varsType sh) undefined) (expToGroundVar sh) undefined
varout (ArgVar sh) = ArgArray Out (ArrayR (fromJust $ typeShape $ varsType sh) (error "fake")) (expToGroundVar sh) (error "fake")


slv :: (forall sh e. f (Out sh e) -> f (Var' sh)) -> SubArgs args args' -> PreArgs f args -> PreArgs f args'
Expand Down
86 changes: 53 additions & 33 deletions src/Data/Array/Accelerate/Eval.hs
Original file line number Diff line number Diff line change
Expand Up @@ -83,15 +83,17 @@ class ( MakesILP op
varToValue :: BackendClusterArg2 op env (Var' sh) -> BackendClusterArg2 op env (Value sh e)
varToSh :: BackendClusterArg2 op env (Var' sh) -> BackendClusterArg2 op env (Sh sh e)
shToVar :: BackendClusterArg2 op env (Sh sh e) -> BackendClusterArg2 op env (Var' sh )
shrinkOrGrow :: BackendClusterArg2 op env (m sh e') -> BackendClusterArg2 op env (m sh e)
shrinkOrGrow :: Arg env (n sh e') -> Arg env (n sh e) -> BackendClusterArg2 op env (m sh e') -> BackendClusterArg2 op env (m sh e)
addTup :: BackendClusterArg2 op env (v sh e) -> BackendClusterArg2 op env (v sh ((), e))
unitToVar :: BackendClusterArg2 op env (m sh ()) -> BackendClusterArg2 op env (Var' sh )
varToUnit :: BackendClusterArg2 op env (Var' sh) -> BackendClusterArg2 op env (m sh ())
inToVar :: BackendClusterArg2 op env (In sh e) -> BackendClusterArg2 op env (Var' sh )
pairinfo :: BackendClusterArg2 op env (m sh a) -> BackendClusterArg2 op env (m sh b) -> BackendClusterArg2 op env (m sh (a,b))
pairinfo :: Arg env (n sh (a,b)) -> BackendClusterArg2 op env (m sh a) -> BackendClusterArg2 op env (m sh b) -> BackendClusterArg2 op env (m sh (a,b))
-- pairinfo a b = if shrinkOrGrow a == b then shrinkOrGrow a else error $ "pairing unequal: " <> show a <> ", " <> show b
unpairinfo :: BackendClusterArg2 op env (m sh (a,b)) -> (BackendClusterArg2 op env (m sh a), BackendClusterArg2 op env (m sh b))
unpairinfo x = (shrinkOrGrow x, shrinkOrGrow x)
unpairinfo :: Arg env (n sh (a,b)) -> BackendClusterArg2 op env (m sh (a,b)) -> (BackendClusterArg2 op env (m sh a), BackendClusterArg2 op env (m sh b))
unpairinfo arg@(ArgArray m (ArrayR shr (TupRpair a b)) sh (TupRpair a' b')) x =
( shrinkOrGrow arg (ArgArray m (ArrayR shr a) sh a') x
, shrinkOrGrow arg (ArgArray m (ArrayR shr b) sh b') x)

foo :: StaticClusterAnalysis op => SubArgs big small -> Args env small -> BackendArgs op env (OutArgsOf small) -> FEnv op env -> BackendCluster op small -> BackendArgs op env (OutArgsOf big)
foo SubArgsNil ArgsNil ArgsNil _ _ = ArgsNil
Expand All @@ -102,33 +104,51 @@ foo (SubArgKeep `SubArgsLive` subargs) (a:>:as) bs env (_ :>: cs) = cas
ArgVar _ -> foo subargs as bs env cs
ArgExp _ -> foo subargs as bs env cs
ArgFun _ -> foo subargs as bs env cs
foo (SubArgOut _ `SubArgsLive` subargs) (_ :>: as) (b:>:bs) env (_ :>: cs) = shrinkOrGrow b :>: foo subargs as bs env cs
foo (SubArgOut s `SubArgsLive` subargs) (a :>: as) (b:>:bs) env (c :>: cs) = shrinkOrGrow a (grow' s a) b :>: foo subargs as bs env cs
foo (SubArgsDead subargs) (a :>: as) bs env (c :>: cs) =
shToOut (varToSh (def a env c)) :>: foo subargs as bs env cs

grow' :: SubTupR big small -> Arg env (m sh small) -> Arg env (m sh big)
grow' SubTupRskip (ArgArray m (ArrayR shr ty) sh buf) = ArgArray m (ArrayR shr (TupRsingle $ error "fused away output")) sh (TupRsingle $ error "fused away output")
grow' SubTupRkeep a = a
grow' (SubTupRpair l r) a = error "todo"

makeBackendArg :: forall op env args. StaticClusterAnalysis op => Args env args -> FEnv op env -> Cluster op args -> BackendCluster op args -> BackendArgs op env args
makeBackendArg args env c b = go args c (defaultOuts args b) b
where
go :: forall args. Args env args -> Cluster op args -> BackendArgs op env (OutArgsOf args) -> BackendCluster op args -> BackendArgs op env args
go args (Fused f l r) outputs bs = let
backR = go (right f args) r (rightB args f outputs) (right' bcaid bcaid f bs)
backR = go (right f args) r (rightB args f outputs) (right' (const bcaid) bcaid f bs)
backL = go (left f args) l (backleft f backR outputs) (left' bcaid f bs)
in fuseBack f backL backR
go args (Op (SLV (SOp (SOAOp op soa) (SA sort unsort)) subargs) _l) outputs bs =
slv outToVar subargs
. sort
. soaExpand uncombineB soa
. mapArgs snd'
. soaExpand uncombineB' soa
. zipArgs (soaShrink combine soa . unsort $ slv' varout subargs args)
$ onOp @op
op
(forgetIn (soaShrink combine soa . unsort $ slv' varout subargs args)
. soaShrink combineB soa . unsort $ inventIn (slv' varout subargs args) (foo subargs args outputs env bs))
. mapArgs snd'
. soaShrink combineB' soa . unsort $ zipArgs (slv' varout subargs args) $ inventIn (slv' varout subargs args) (foo subargs args outputs env bs))
(soaShrink combine soa . unsort $ slv' varout subargs args)
env

combineB :: BackendClusterArg2 op env (f l) -> BackendClusterArg2 op env (f r) -> BackendClusterArg2 op env (f (l,r))
combineB = unsafeCoerce $ pairinfo @op
uncombineB :: BackendClusterArg2 op env (f (l,r)) -> (BackendClusterArg2 op env (f l), BackendClusterArg2 op env (f r))
combineB :: Arg env (g (l,r)) -> BackendClusterArg2 op env (f l) -> BackendClusterArg2 op env (f r) -> BackendClusterArg2 op env (f (l,r))
combineB = unsafeCoerce $ pairinfo @op
uncombineB :: Arg env (g (l,r)) -> BackendClusterArg2 op env (f (l,r)) -> (BackendClusterArg2 op env (f l), BackendClusterArg2 op env (f r))
uncombineB = unsafeCoerce $ unpairinfo @op
combineB' :: Both (Arg env) (BackendClusterArg2 op env) (g l)
-> Both (Arg env) (BackendClusterArg2 op env) (g r)
-> Both (Arg env) (BackendClusterArg2 op env) (g (l, r))
combineB' (Both al l) (Both ar r) = Both (combine al ar) (combineB (combine al ar) l r)
uncombineB' :: Both (Arg env) (BackendClusterArg2 op env) (g (l, r))
-> (Both (Arg env) (BackendClusterArg2 op env) (g l),
Both (Arg env) (BackendClusterArg2 op env) (g r))
uncombineB' (Both a x) = let (al, ar) = split a
(l, r) = uncombineB a x
in (Both al l, Both ar r)

defaultOuts :: Args env args -> BackendCluster op args -> BackendArgs op env (OutArgsOf args)
defaultOuts args backendcluster = forgetIn args $ fuseArgsWith args backendcluster (\arg b -> def arg env b)
Expand All @@ -155,7 +175,7 @@ makeBackendArg args env c b = go args c (defaultOuts args b) b
-> Fusion largs rargs args
-> BackendArgs op env (OutArgsOf args)
-> BackendArgs op env (OutArgsOf rargs)
rightB args f = forgetIn (right f args) . right' (valueToIn . varToValue) (valueToIn . outToValue) f . inventIn args
rightB args f = forgetIn (right f args) . right' (const $ valueToIn . varToValue) (valueToIn . outToValue) f . inventIn args

backleft :: forall largs rargs args
. StaticClusterAnalysis op
Expand Down Expand Up @@ -230,8 +250,8 @@ splitFromArg' :: (EvalOp op) => FromArg' op env (Value sh (l,r)) -> (FromArg' op
splitFromArg' (FromArg v) = bimap FromArg FromArg $ unpair' v

pairInArg :: (EvalOp op) => Arg env (arg (l,r)) -> BackendArgEnvElem op env (InArg (arg l)) -> BackendArgEnvElem op env (InArg (arg r)) -> BackendArgEnvElem op env (InArg (arg (l,r)))
pairInArg (ArgArray In _ _ _) l r = getCompose $ pair' (Compose l) (Compose r)
pairInArg (ArgArray Out _ _ _) l r = getCompose $ pair' (Compose l) (Compose r)
pairInArg a@(ArgArray In _ _ _) (BAE x b) (BAE y d) = BAE (pair' x y) (pairinfo a b d)
pairInArg a@(ArgArray Out _ _ _) (BAE x b) (BAE y d) = BAE (pair' x y) (pairinfo a b d)
pairInArg _ _ _ = error "SOA'd non-array args"

evalCluster :: (EvalOp op) => Cluster op args -> BackendCluster op args -> Args env args -> FEnv op env -> Index op -> EvalMonad op ()
Expand Down Expand Up @@ -382,25 +402,25 @@ instance TupRmonoid (Sh' op sh) where
unpair' (Shape' shr sh) = (Shape' shr sh, Shape' shr sh)


instance EvalOp op => TupRmonoid (Compose (BackendArgEnvElem op env) (Value sh)) where
pair' (Compose (BAE x b)) (Compose (BAE y d)) =
Compose $ BAE (pair' x y) (pairinfo b d)
unpair' (Compose (BAE x b)) =
biliftA2
(Compose .* BAE)
(Compose .* BAE)
(unpair' x)
(unpairinfo b)

instance EvalOp op => TupRmonoid (Compose (BackendArgEnvElem op env) (Sh sh)) where
pair' (Compose (BAE x b)) (Compose (BAE y d)) =
Compose $ BAE (pair' x y) (pairinfo b d)
unpair' (Compose (BAE x b)) =
biliftA2
(Compose .* BAE)
(Compose .* BAE)
(unpair' x)
(unpairinfo b)
-- instance EvalOp op => TupRmonoid (Compose (BackendArgEnvElem op env) (Value sh)) where
-- pair' (Compose (BAE x b)) (Compose (BAE y d)) =
-- Compose $ BAE (pair' x y) (Debug.Trace.trace "pair'" $ pairinfo b d)
-- unpair' (Compose (BAE x b)) =
-- biliftA2
-- (Compose .* BAE)
-- (Compose .* BAE)
-- (unpair' x)
-- (unpairinfo _ b)

-- instance EvalOp op => TupRmonoid (Compose (BackendArgEnvElem op env) (Sh sh)) where
-- pair' (Compose (BAE x b)) (Compose (BAE y d)) =
-- Compose $ BAE (pair' x y) (pairinfo b d)
-- unpair' (Compose (BAE x b)) =
-- biliftA2
-- (Compose .* BAE)
-- (Compose .* BAE)
-- (unpair' x)
-- (unpairinfo _ b)



Expand Down
4 changes: 2 additions & 2 deletions src/Data/Array/Accelerate/Interpreter.hs
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ instance StaticClusterAnalysis InterpretOp where
varToValue (BCA f sz) = BCA f sz
varToSh (BCA f sz) = BCA f sz
shToVar (BCA f sz) = BCA f sz
shrinkOrGrow (BCA f sz) = BCA f sz
shrinkOrGrow _ _ (BCA f sz) = BCA f sz
addTup (BCA f sz) = BCA f sz
unitToVar (BCA f sz) = BCA f sz
varToUnit (BCA f sz) = BCA f sz
Expand Down Expand Up @@ -675,7 +675,7 @@ iterationsize (Op _ _) ((BCA _ n) :>: args) = if n==0 then iterationsize (Op und
iterationsize (P.Fused f l r) b =
let lsz = iterationsize l (left' (\(BCA f x) -> BCA f x) f b)
in if lsz == 0
then iterationsize r (right' (\(BCA f x)->BCA f x) (\(BCA f x)->BCA f x) f b)
then iterationsize r (right' (\_ (BCA f x)->BCA f x) (\(BCA f x)->BCA f x) f b)
else lsz

-- iterationsize (Op _) ArgsNil env = Nothing
Expand Down
2 changes: 1 addition & 1 deletion src/Data/Array/Accelerate/Pretty/Partitioned.hs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ instance PrettyOp op => PrettyOp (Clustered op) where
prettyOpWithArgs env (Clustered c _) = prettyOpWithArgs env c

instance PrettyOp op => PrettyOp (Cluster op) where
prettyOp (Fused _ l r) = "Fused (" <> prettyOp l <> ", " <> prettyOp r
prettyOp (Fused _ l r) = "Fused" -- (" <> prettyOp l <> ", " <> prettyOp r
prettyOp (Op (SLV (SOp (SOAOp op _) _) _) _) = prettyOp op
prettyOpWithArgs env (Fused f l r) args = "Fused" -- (" <> prettyOpWithArgs env l (left f args) <> ", " <> prettyOpWithArgs env r (right f args)
prettyOpWithArgs env (Op (SLV (SOp (SOAOp op soa) (SA _ unsort)) subargs) _) args =
Expand Down
4 changes: 2 additions & 2 deletions src/Data/Array/Accelerate/Trafo.hs
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,8 @@ testWithObjective obj f
++ Pretty.renderForTerminal (Pretty.prettyAfun operation)
++ "\n\nPartitionedAcc:\n"
++ Pretty.renderForTerminal (Pretty.prettyAfun partitioned)
++ "\nSLV'd PartitionedAcc:\n"
++ Pretty.renderForTerminal (Pretty.prettyAfun slvpartitioned)
-- ++ "\nSLV'd PartitionedAcc:\n"
-- ++ Pretty.renderForTerminal (Pretty.prettyAfun slvpartitioned)
++ "\n\nSchedule:\n"
++ Pretty.renderForTerminal (Pretty.prettySchedule schedule)
where
Expand Down

0 comments on commit 3ce8345

Please sign in to comment.