Skip to content

Commit

Permalink
fix permute on target arrays of tuples
Browse files Browse the repository at this point in the history
  • Loading branch information
dpvanbalen committed Mar 19, 2024
1 parent 8f02840 commit 0467ac4
Showing 1 changed file with 11 additions and 9 deletions.
20 changes: 11 additions & 9 deletions src/Data/Array/Accelerate/AST/Partitioned.hs
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ data SOA arg appendto result where

soaShrink :: forall args expanded f
. (forall a. Show (f a))
=> (forall l r g. f (g l) -> f (g r) -> f (g (l,r)))
=> (forall l r g. f (g l) -> f (g r) -> f (g (l,r)))
-> SOAs args expanded -> PreArgs f expanded -> PreArgs f args
soaShrink _ SOArgsNil ArgsNil = ArgsNil
soaShrink f (SOArgsCons soas soa) args = case go soa args of (arg :>: args') -> arg :>: soaShrink f soas args'
Expand Down Expand Up @@ -212,7 +212,7 @@ data Cluster op args where
-- datatype describing how to combine the arguments of two fused clusters
data Fusion largs rars args where
EmptyF :: Fusion () () ()
Vertical :: ArrayR (Array sh e)
Vertical :: ArrayR (Array sh e)
-> Fusion l r a
-> Fusion (Out sh e -> l) (In sh e -> r) (Var' sh -> a)
Horizontal :: Fusion l r a
Expand Down Expand Up @@ -407,12 +407,12 @@ onZipped :: (f a -> f b -> f c) -> (g a -> g b -> g c) -> (Both f g a -> Both f
onZipped f g (Both fa ga) (Both fb gb) = Both (f fa fb) (g ga gb)

-- assumes that the arguments are already sorted!
fuse :: MakesILP op
fuse :: MakesILP op
=> LabelledArgsOp op env l -> LabelledArgsOp op env r -> PreArgs f l -> PreArgs f r -> Clustered op l -> Clustered op r
-> (forall sh e. f (Out sh e) -> f (In sh e) -> f (Var' sh))
-> (forall args. PreArgs f args -> Clustered op args -> result)
-> result
fuse labl labr largs rargs (Clustered cl bl) (Clustered cr br) c k = fuse' labl labr (zipArgs largs bl) (zipArgs rargs br) cl cr (onZipped c combineBackendClusterArg)
fuse labl labr largs rargs (Clustered cl bl) (Clustered cr br) c k = fuse' labl labr (zipArgs largs bl) (zipArgs rargs br) cl cr (onZipped c combineBackendClusterArg)
$ \args c' -> k (mapArgs fst' args) (Clustered c' $ mapArgs snd' args)

-- assumes that the arguments are already sorted!
Expand All @@ -435,8 +435,10 @@ mkFused (l'@(LOp l ((Arr (TupRsingle (C.Const (ELabel llab))), lls))lop) :>: ls)
| (llab,lop) < (rlab,rop) = mkFused ls (r':>:rs) $ \f -> k (addleft l f)
| (llab,lop) > (rlab,rop) = mkFused (l':>:ls) rs $ \f -> k (addright r f)
| otherwise = error "simple math, the truth cannot be questioned"
mkFused ((LOp _ ((Arr TupRpair{}, _))_) :>: _) _ _ = error "not soa'd array"
mkFused _ ((LOp _ ((Arr TupRpair{}, _))_) :>: _) _ = error "not soa'd array"
mkFused ((LOp l@(ArgArray Mut _ _ _) _ _) :>: ls) rs k = mkFused ls rs $ \f -> k (addleft l f)
mkFused ls ((LOp r@(ArgArray Mut _ _ _) _ _) :>: rs) k = mkFused ls rs $ \f -> k (addright r f)
mkFused ((LOp _ (Arr TupRpair{}, _)_) :>: _) _ _ = error "not soa'd array"
mkFused _ ((LOp _ (Arr TupRpair{}, _)_) :>: _) _ = error "not soa'd array"

addleft :: Arg env arg -> Fusion left right args -> Fusion (arg->left) right (arg->args)
addleft (ArgVar _ ) f = IntroL f
Expand Down Expand Up @@ -469,7 +471,7 @@ singleton l largs op k = mkSOAs (mapArgs (\(LOp a _ _) -> a) largs) $ \soas ->
-- (subargsId $ sort $ soaExpand splitLabelledArgsOp soas largs)

sortArgs :: LabelledArgs env args -> (forall sorted. SortedArgs args sorted -> r) -> r
sortArgs args k =
sortArgs args k =
-- if nub ls /= ls then error "not sure if this works" -- this means that some arguments have identical sets of labels. This should only be a problem if two array arguments share labelsets.
-- else
k $ SA
Expand All @@ -489,11 +491,11 @@ mkSOAs ArgsNil k = k SOArgsNil
mkSOAs (a:>:args) k = mkSOAs args $ \soas -> mkSOA a $ \soa -> k (SOArgsCons soas soa)

mkSOA :: Arg env arg -> (forall result. SOA arg toappend result -> r) -> r
mkSOA (ArgArray In (ArrayR shr (TupRpair tl tr)) sh (TupRpair bufl bufr)) k =
mkSOA (ArgArray In (ArrayR shr (TupRpair tl tr)) sh (TupRpair bufl bufr)) k =
mkSOA (ArgArray In (ArrayR shr tr ) sh bufr ) $ \soar ->
mkSOA (ArgArray In (ArrayR shr tl ) sh bufl ) $ \soal ->
k (SOArgTup soar soal)
mkSOA (ArgArray Out (ArrayR shr (TupRpair tl tr)) sh (TupRpair bufl bufr)) k =
mkSOA (ArgArray Out (ArrayR shr (TupRpair tl tr)) sh (TupRpair bufl bufr)) k =
mkSOA (ArgArray Out (ArrayR shr tr ) sh bufr ) $ \soar ->
mkSOA (ArgArray Out (ArrayR shr tl ) sh bufl ) $ \soal ->
k (SOArgTup soar soal)
Expand Down

0 comments on commit 0467ac4

Please sign in to comment.