Skip to content

Commit

Permalink
Differentiate fold and fold1!
Browse files Browse the repository at this point in the history
  • Loading branch information
tomsmeding committed Oct 12, 2020
1 parent 0240cfd commit 7aab1b5
Show file tree
Hide file tree
Showing 10 changed files with 354 additions and 55 deletions.
1 change: 1 addition & 0 deletions src/Data/Array/Accelerate/Trafo/AD.hs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ convertAcc (OpenAcc (Reshape shr she a)) = OpenAcc (Reshape shr (convertExp she)
convertAcc (OpenAcc (Use rep a)) = OpenAcc (Use rep a)
convertAcc (OpenAcc (Fold f e a)) = OpenAcc (Fold (convertFun f) (convertExp <$> e) (convertAcc a))
convertAcc (OpenAcc (Scan dir f e a)) = OpenAcc (Scan dir (convertFun f) (convertExp <$> e) (convertAcc a))
convertAcc (OpenAcc (Scan' dir f e a)) = OpenAcc (Scan' dir (convertFun f) (convertExp e) (convertAcc a))
convertAcc (OpenAcc (ZipWith ty f a1 a2)) = OpenAcc (ZipWith ty (convertFun f) (convertAcc a1) (convertAcc a2))
convertAcc (OpenAcc (Permute f a1 fi a2)) = OpenAcc (Permute (convertFun f) (convertAcc a1) (convertFun fi) (convertAcc a2))
convertAcc (OpenAcc (Backpermute rep e f a)) = OpenAcc (Backpermute rep (convertExp e) (convertFun f) (convertAcc a))
Expand Down
258 changes: 238 additions & 20 deletions src/Data/Array/Accelerate/Trafo/AD/ADAcc.hs

Large diffs are not rendered by default.

24 changes: 3 additions & 21 deletions src/Data/Array/Accelerate/Trafo/AD/ADExp.hs
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,9 @@ varsToArgs (TupRpair vars1 vars2) =
ex2 = varsToArgs vars2
in Pair (TupRpair (etypeOf ex1) (etypeOf ex2)) ex1 ex2

-- TODO: produceGradient should take the ExpVars value from BEFORE varsToArgs,
-- not after. That eliminates the error case if the argument is not
-- Nil/Pair/Arg.
produceGradient :: DMap (Idx args) (EDLabelT Int)
-> EContext Int env
-> OpenExp () aenv unused unused2 args t
Expand Down Expand Up @@ -961,27 +964,6 @@ dual' nodemap lbl (Context labelenv bindmap) contribmap =
(Let lhs adjoint)

expr -> trace ("\n!! " ++ show expr) undefined
where
smartPair :: OpenExp env aenv lab alab args a -> OpenExp env aenv lab alab args b -> OpenExp env aenv lab alab args (a, b)
smartPair a b = Pair (TupRpair (etypeOf a) (etypeOf b)) a b

smartNeg :: NumType t -> OpenExp env aenv lab alab args t -> OpenExp env aenv lab alab args t
smartNeg ty a = PrimApp (TupRsingle (SingleScalarType (NumSingleType ty))) (A.PrimNeg ty) a

smartRecip :: FloatingType t -> OpenExp env aenv lab alab args t -> OpenExp env aenv lab alab args t
smartRecip ty a = PrimApp (TupRsingle (SingleScalarType (NumSingleType (FloatingNumType ty)))) (A.PrimRecip ty) a

smartSub :: NumType t -> OpenExp env aenv lab alab args t -> OpenExp env aenv lab alab args t -> OpenExp env aenv lab alab args t
smartSub ty a b = PrimApp (TupRsingle (SingleScalarType (NumSingleType ty))) (A.PrimSub ty) (smartPair a b)

smartMul :: NumType t -> OpenExp env aenv lab alab args t -> OpenExp env aenv lab alab args t -> OpenExp env aenv lab alab args t
smartMul ty a b = PrimApp (TupRsingle (SingleScalarType (NumSingleType ty))) (A.PrimMul ty) (smartPair a b)

smartFDiv :: FloatingType t -> OpenExp env aenv lab alab args t -> OpenExp env aenv lab alab args t -> OpenExp env aenv lab alab args t
smartFDiv ty a b = PrimApp (TupRsingle (SingleScalarType (NumSingleType (FloatingNumType ty)))) (A.PrimFDiv ty) (smartPair a b)

smartGt :: SingleType t -> OpenExp env aenv lab alab args t -> OpenExp env aenv lab alab args t -> OpenExp env aenv lab alab args A.PrimBool
smartGt ty a b = PrimApp (TupRsingle scalarType) (A.PrimGt ty) (smartPair a b)

-- TODO: make a new abstraction after the refactor, possibly inspired by this function, which was the abstraction pre-refactor
-- dualStoreAdjoint
Expand Down
32 changes: 30 additions & 2 deletions src/Data/Array/Accelerate/Trafo/AD/Acc.hs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ data OpenAcc aenv lab alab args t where
-> OpenAcc aenv lab alab args (Array sh t3)

Fold :: ArrayR (Array sh e)
-> ExpLambda1 aenv lab alab sh (e, e) e
-> Fun aenv lab alab ((e, e) -> e)
-> Maybe (Exp aenv lab alab () e)
-> OpenAcc aenv lab alab args (Array (sh, Int) e)
-> OpenAcc aenv lab alab args (Array sh e)
Expand All @@ -77,6 +77,20 @@ data OpenAcc aenv lab alab args t where
-> OpenAcc aenv lab alab args (Array (sh, Int) e)
-> OpenAcc aenv lab alab args (Array sh e)

Scan :: ArrayR (Array (sh, Int) e)
-> A.Direction
-> Fun aenv lab alab ((e, e) -> e)
-> Maybe (Exp aenv lab alab () e)
-> OpenAcc aenv lab alab args (Array (sh, Int) e)
-> OpenAcc aenv lab alab args (Array (sh, Int) e)

Scan' :: ArraysR (Array (sh, Int) e, Array sh e)
-> A.Direction
-> Fun aenv lab alab ((e, e) -> e)
-> Exp aenv lab alab () e
-> OpenAcc aenv lab alab args (Array (sh, Int) e)
-> OpenAcc aenv lab alab args (Array (sh, Int) e, Array sh e)

Generate :: ArrayR (Array sh e)
-> Exp aenv lab alab () sh
-> ExpLambda1 aenv lab alab sh sh e
Expand Down Expand Up @@ -213,9 +227,21 @@ showsAcc se d (ZipWith _ f e1 e2) =
showsAcc se d (Fold _ f me0 e) =
showParen (d > 10) $
showString (maybe "fold1 " (const "fold ") me0) .
showsLambda (se { seEnv = [] }) 11 f . showString " " .
showsFun (se { seEnv = [] }) 11 f . showString " " .
maybe id (\e0 -> showsExp (se { seEnv = [] }) 11 e0 . showString " ") me0 .
showsAcc se 11 e
showsAcc se d (Scan _ dir f me0 e) =
showParen (d > 10) $
showString ("scan" ++ (case dir of A.LeftToRight -> "l" ; A.RightToLeft -> "r") ++ maybe "1" (const "") me0 ++ " ") .
showsFun (se { seEnv = [] }) 11 f . showString " " .
maybe id (\e0 -> showsExp (se { seEnv = [] }) 11 e0 . showString " ") me0 .
showsAcc se 11 e
showsAcc se d (Scan' _ dir f e0 e) =
showParen (d > 10) $
showString ("scan" ++ (case dir of A.LeftToRight -> "l" ; A.RightToLeft -> "r") ++ "' ") .
showsFun (se { seEnv = [] }) 11 f . showString " " .
showsExp (se { seEnv = [] }) 11 e0 . showString " " .
showsAcc se 11 e
showsAcc se d (Backpermute _ dim f e) =
showParen (d > 10) $
showString "backpermute " .
Expand Down Expand Up @@ -304,6 +330,8 @@ atypeOf (Map ty _ _) = TupRsingle ty
atypeOf (ZipWith ty _ _ _) = TupRsingle ty
atypeOf (Generate ty _ _) = TupRsingle ty
atypeOf (Fold ty _ _ _) = TupRsingle ty
atypeOf (Scan ty _ _ _ _) = TupRsingle ty
atypeOf (Scan' ty _ _ _ _) = ty
atypeOf (Backpermute ty _ _ _) = TupRsingle ty
atypeOf (Permute ty _ _ _ _) = TupRsingle ty
atypeOf (Replicate ty _ _ _) = TupRsingle ty
Expand Down
25 changes: 25 additions & 0 deletions src/Data/Array/Accelerate/Trafo/AD/Exp.hs
Original file line number Diff line number Diff line change
Expand Up @@ -470,3 +470,28 @@ mkBool b
Const scalarType tag
| otherwise = error $ "Bool does not have a " ++ constrName ++ " constructor?"
where constrName = if b then "True" else "False"


smartPair :: OpenExp env aenv lab alab args a -> OpenExp env aenv lab alab args b -> OpenExp env aenv lab alab args (a, b)
smartPair a b = Pair (TupRpair (etypeOf a) (etypeOf b)) a b

smartNeg :: NumType t -> OpenExp env aenv lab alab args t -> OpenExp env aenv lab alab args t
smartNeg ty a = PrimApp (TupRsingle (SingleScalarType (NumSingleType ty))) (A.PrimNeg ty) a

smartRecip :: FloatingType t -> OpenExp env aenv lab alab args t -> OpenExp env aenv lab alab args t
smartRecip ty a = PrimApp (TupRsingle (SingleScalarType (NumSingleType (FloatingNumType ty)))) (A.PrimRecip ty) a

smartAdd :: NumType t -> OpenExp env aenv lab alab args t -> OpenExp env aenv lab alab args t -> OpenExp env aenv lab alab args t
smartAdd ty a b = PrimApp (TupRsingle (SingleScalarType (NumSingleType ty))) (A.PrimAdd ty) (smartPair a b)

smartSub :: NumType t -> OpenExp env aenv lab alab args t -> OpenExp env aenv lab alab args t -> OpenExp env aenv lab alab args t
smartSub ty a b = PrimApp (TupRsingle (SingleScalarType (NumSingleType ty))) (A.PrimSub ty) (smartPair a b)

smartMul :: NumType t -> OpenExp env aenv lab alab args t -> OpenExp env aenv lab alab args t -> OpenExp env aenv lab alab args t
smartMul ty a b = PrimApp (TupRsingle (SingleScalarType (NumSingleType ty))) (A.PrimMul ty) (smartPair a b)

smartFDiv :: FloatingType t -> OpenExp env aenv lab alab args t -> OpenExp env aenv lab alab args t -> OpenExp env aenv lab alab args t
smartFDiv ty a b = PrimApp (TupRsingle (SingleScalarType (NumSingleType (FloatingNumType ty)))) (A.PrimFDiv ty) (smartPair a b)

smartGt :: SingleType t -> OpenExp env aenv lab alab args t -> OpenExp env aenv lab alab args t -> OpenExp env aenv lab alab args A.PrimBool
smartGt ty a b = PrimApp (TupRsingle scalarType) (A.PrimGt ty) (smartPair a b)
15 changes: 14 additions & 1 deletion src/Data/Array/Accelerate/Trafo/AD/Pretty.hs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
{-# LANGUAGE GADTs #-}
module Data.Array.Accelerate.Trafo.AD.Pretty (prettyPrint) where

import qualified Data.Array.Accelerate.AST as A
import qualified Data.Array.Accelerate.AST.Var as A
import Data.Array.Accelerate.Representation.Array
import Data.Array.Accelerate.Trafo.AD.Acc
Expand Down Expand Up @@ -194,9 +195,21 @@ layoutAcc se d (ZipWith _ f e1 e2) =
layoutAcc se d (Fold _ f me0 e) =
parenthesise (d > 10) $
lprefix (maybe "fold1 " (const "fold ") me0)
(lseq' $ concat [[layoutLambda (se { seEnv = [] }) 11 f]
(lseq' $ concat [[layoutFun (se { seEnv = [] }) 11 f]
,maybe [] (\e0 -> [layoutExp (se { seEnv = [] }) 11 e0]) me0
,[layoutAcc se 11 e]])
layoutAcc se d (Scan _ dir f me0 e) =
parenthesise (d > 10) $
lprefix ("scan" ++ (case dir of A.LeftToRight -> "l" ; A.RightToLeft -> "r") ++ maybe "1" (const "") me0 ++ " ")
(lseq' $ concat [[layoutFun (se { seEnv = [] }) 11 f]
,maybe [] (\e0 -> [layoutExp (se { seEnv = [] }) 11 e0]) me0
,[layoutAcc se 11 e]])
layoutAcc se d (Scan' _ dir f e0 e) =
parenthesise (d > 10) $
lprefix ("scan" ++ (case dir of A.LeftToRight -> "l" ; A.RightToLeft -> "r") ++ "' ")
(lseq' [layoutFun (se { seEnv = [] }) 11 f
,layoutExp (se { seEnv = [] }) 11 e0
,layoutAcc se 11 e])
layoutAcc se d (Backpermute _ dim f e) =
parenthesise (d > 10) $
lprefix "backpermute "
Expand Down
17 changes: 12 additions & 5 deletions src/Data/Array/Accelerate/Trafo/AD/Simplify.hs
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,15 @@ goAcc = \case
Map ty lam a -> Map ty !$! simplifyLam1 lam !**! goAcc a
ZipWith ty lam a1 a2 -> ZipWith ty !$! simplifyLam1 lam !**! goAcc a1 !**! goAcc a2
Generate ty e lam -> Generate ty !$! goExp' e !**! simplifyLam1 lam
Fold ty lam me0 a -> Fold ty !$! simplifyLam1 lam
!**! (case me0 of Just e0 -> Just !$! goExp' e0
Nothing -> returnS Nothing)
!**! goAcc a
Fold ty f me0 a -> Fold ty !$! simplifyFun f
!**! (case me0 of Just e0 -> Just !$! goExp' e0
Nothing -> returnS Nothing)
!**! goAcc a
Scan ty dir f me0 a -> Scan ty dir !$! simplifyFun f
!**! (case me0 of Just e0 -> Just !$! goExp' e0
Nothing -> returnS Nothing)
!**! goAcc a
Scan' ty dir f e0 a -> Scan' ty dir !$! simplifyFun f !**! goExp' e0 !**! goAcc a
Sum ty a -> Sum ty !$! goAcc a
Replicate ty slt she a -> Replicate ty slt !$! goExp' she !**! goAcc a
Slice ty slt a she -> Slice ty slt !$! goAcc a !**! goExp' she
Expand Down Expand Up @@ -206,7 +211,9 @@ inlineA f = \case
Map ty lam a -> Map ty (inlineALam f lam) (inlineA f a)
ZipWith ty lam a1 a2 -> ZipWith ty (inlineALam f lam) (inlineA f a1) (inlineA f a2)
Generate ty e lam -> Generate ty (inlineAE f e) (inlineALam f lam)
Fold ty lam me0 a -> Fold ty (inlineALam f lam) (inlineAE f <$> me0) (inlineA f a)
Fold ty fun me0 a -> Fold ty (inlineAEF f fun) (inlineAE f <$> me0) (inlineA f a)
Scan ty dir fun me0 a -> Scan ty dir (inlineAEF f fun) (inlineAE f <$> me0) (inlineA f a)
Scan' ty dir fun e0 a -> Scan' ty dir (inlineAEF f fun) (inlineAE f e0) (inlineA f a)
Sum ty a -> Sum ty (inlineA f a)
Replicate ty slt she a -> Replicate ty slt (inlineAE f she) (inlineA f a)
Slice ty slt a she -> Slice ty slt (inlineA f a) (inlineAE f she)
Expand Down
8 changes: 6 additions & 2 deletions src/Data/Array/Accelerate/Trafo/AD/Sink.hs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,9 @@ sinkAcc _ Anil = Anil
sinkAcc k (Acond ty c t e) = Acond ty (sinkExpAenv k c) (sinkAcc k t) (sinkAcc k e)
sinkAcc k (Map ty f e) = Map ty (sinkFunAenv k <$> f) (sinkAcc k e)
sinkAcc k (ZipWith ty f e1 e2) = ZipWith ty (sinkFunAenv k <$> f) (sinkAcc k e1) (sinkAcc k e2)
sinkAcc k (Fold ty f me0 e) = Fold ty (sinkFunAenv k <$> f) (sinkExpAenv k <$> me0) (sinkAcc k e)
sinkAcc k (Fold ty f me0 e) = Fold ty (sinkFunAenv k f) (sinkExpAenv k <$> me0) (sinkAcc k e)
sinkAcc k (Scan ty dir f me0 e) = Scan ty dir (sinkFunAenv k f) (sinkExpAenv k <$> me0) (sinkAcc k e)
sinkAcc k (Scan' ty dir f e0 e) = Scan' ty dir (sinkFunAenv k f) (sinkExpAenv k e0) (sinkAcc k e)
sinkAcc k (Backpermute ty dim f e) = Backpermute ty (sinkExpAenv k dim) (sinkFunAenv k f) (sinkAcc k e)
sinkAcc k (Permute ty comb def pf e) = Permute ty (sinkFunAenv k comb) (sinkAcc k def) (sinkFunAenv k pf) (sinkAcc k e)
sinkAcc k (Sum ty e) = Sum ty (sinkAcc k e)
Expand Down Expand Up @@ -157,7 +159,9 @@ aCheckClosedInTagval tv expr = case expr of
Acond ty c t e -> Acond ty <$> eCheckAClosedInTagval tv c <*> aCheckClosedInTagval tv t <*> aCheckClosedInTagval tv e
Map ty f e -> Map ty <$> traverse (efCheckAClosedInTagval tv) f <*> aCheckClosedInTagval tv e
ZipWith ty f e1 e2 -> ZipWith ty <$> traverse (efCheckAClosedInTagval tv) f <*> aCheckClosedInTagval tv e1 <*> aCheckClosedInTagval tv e2
Fold ty f me0 e -> Fold ty <$> traverse (efCheckAClosedInTagval tv) f <*> traverse (eCheckAClosedInTagval tv) me0 <*> aCheckClosedInTagval tv e
Fold ty f me0 e -> Fold ty <$> efCheckAClosedInTagval tv f <*> traverse (eCheckAClosedInTagval tv) me0 <*> aCheckClosedInTagval tv e
Scan ty dir f me0 e -> Scan ty dir <$> efCheckAClosedInTagval tv f <*> traverse (eCheckAClosedInTagval tv) me0 <*> aCheckClosedInTagval tv e
Scan' ty dir f e0 e -> Scan' ty dir <$> efCheckAClosedInTagval tv f <*> eCheckAClosedInTagval tv e0 <*> aCheckClosedInTagval tv e
Backpermute ty dim f e -> Backpermute ty <$> eCheckAClosedInTagval tv dim <*> efCheckAClosedInTagval tv f <*> aCheckClosedInTagval tv e
Permute ty cf def pf e -> Permute ty <$> efCheckAClosedInTagval tv cf <*> aCheckClosedInTagval tv def <*> efCheckAClosedInTagval tv pf <*> aCheckClosedInTagval tv e
Sum ty e -> Sum ty <$> aCheckClosedInTagval tv e
Expand Down
11 changes: 8 additions & 3 deletions src/Data/Array/Accelerate/Trafo/AD/Translate.hs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,11 @@ translateAcc (A.OpenAcc expr) = case expr of
| Nothing <- initval ->
D.Sum (A.arrayR expr) (translateAcc e)
A.Fold f me0 e ->
D.Fold (A.arrayR expr) (Right $ toPairedBinop $ translateFun f) (translateExp <$> me0) (translateAcc e)
D.Fold (A.arrayR expr) (toPairedBinop $ translateFun f) (translateExp <$> me0) (translateAcc e)
A.Scan dir f me0 e ->
D.Scan (A.arrayR expr) dir (toPairedBinop $ translateFun f) (translateExp <$> me0) (translateAcc e)
A.Scan' dir f e0 e ->
D.Scan' (A.arraysR expr) dir (toPairedBinop $ translateFun f) (translateExp e0) (translateAcc e)
A.Generate ty she f ->
D.Generate ty (translateExp she) (Right $ translateFun f)
A.Replicate slt sle e ->
Expand Down Expand Up @@ -228,7 +232,9 @@ untranslateLHSboundAcc toplhs topexpr
D.Acond _ e1 e2 e3 -> A.Acond (untranslateClosedExpA e1 pv) (go e2 pv) (go e3 pv)
D.Map (ArrayR _ ty) (Right f) e -> A.Map ty (untranslateClosedFunA f pv) (go e pv)
D.ZipWith (ArrayR _ ty) (Right f) e1 e2 -> A.ZipWith ty (untranslateClosedFunA (fromPairedBinop f) pv) (go e1 pv) (go e2 pv)
D.Fold _ (Right f) me0 e -> A.Fold (untranslateClosedFunA (fromPairedBinop f) pv) (untranslateClosedExpA <$> me0 <*> Just pv) (go e pv)
D.Fold _ f me0 e -> A.Fold (untranslateClosedFunA (fromPairedBinop f) pv) (untranslateClosedExpA <$> me0 <*> Just pv) (go e pv)
D.Scan _ dir f me0 e -> A.Scan dir (untranslateClosedFunA (fromPairedBinop f) pv) (untranslateClosedExpA <$> me0 <*> Just pv) (go e pv)
D.Scan' _ dir f e0 e -> A.Scan' dir (untranslateClosedFunA (fromPairedBinop f) pv) (untranslateClosedExpA e0 pv) (go e pv)
D.Sum (ArrayR _ (TupRsingle ty@(SingleScalarType (NumSingleType nt)))) e ->
A.Fold (A.Lam (A.LeftHandSideSingle ty) (A.Lam (A.LeftHandSideSingle ty)
(A.Body (A.PrimApp (A.PrimAdd nt)
Expand Down Expand Up @@ -261,7 +267,6 @@ untranslateLHSboundAcc toplhs topexpr
D.Alabel _ -> internalError "AD.untranslateLHSboundAcc: Unexpected Label in untranslate!"
D.Map _ _ _ -> error "Unexpected Map shape in untranslate"
D.ZipWith _ _ _ _ -> error "Unexpected ZipWith shape in untranslate"
D.Fold _ _ _ _ -> error "Unexpected Fold shape in untranslate"
D.Sum _ _ -> error "Unexpected Sum shape in untranslate"
D.Generate _ _ _ -> error "Unexpected Generate shape in untranslate"
D.Reduce _ _ _ _ -> error "Unexpected Reduce shape in untranslate"
Expand Down
18 changes: 17 additions & 1 deletion test/tom/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,21 @@ arrad = do
A.gradientA (\arr -> A.sum $ A.zipWith (*) arr (A.generate (A.shape arr) (\(A.I1 i) -> A.cond (i A.< 5) 0 1)))
(A.use (A.fromList (Z :. 10) [1 :: Float .. 10]))

adfold :: IO ()
adfold = do
-- print $
-- A.gradientA (\arr -> A.maximum arr)
-- (A.use (A.fromList (Z :. 10) [1 :: Float .. 10]))
let input = A.fromList (Z :. 8) [1 :: Float .. 8]

AD.aCompareAD (\arr -> A.maximum arr) input
AD.aCompareAD (\arr -> A.product arr) input

print $
A.gradientA (\arr -> let p = A.product arr
in A.zipWith (+) (A.map (*2) p) (A.map (*3) p))
(A.use input)

main :: IO ()
main = do
-- logistic
Expand All @@ -300,7 +315,8 @@ main = do
-- adtuple1
-- adtuple2
-- adtuple3
arrad
-- arrad
adfold
-- neural
-- neural2
-- Playground.Neural.main
Expand Down

0 comments on commit 7aab1b5

Please sign in to comment.