From 7aab1b588b31a9c1fd9595340c60b5a8be062ce9 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Mon, 12 Oct 2020 21:41:36 +0200 Subject: [PATCH] Differentiate fold and fold1! --- src/Data/Array/Accelerate/Trafo/AD.hs | 1 + src/Data/Array/Accelerate/Trafo/AD/ADAcc.hs | 258 ++++++++++++++++-- src/Data/Array/Accelerate/Trafo/AD/ADExp.hs | 24 +- src/Data/Array/Accelerate/Trafo/AD/Acc.hs | 32 ++- src/Data/Array/Accelerate/Trafo/AD/Exp.hs | 25 ++ src/Data/Array/Accelerate/Trafo/AD/Pretty.hs | 15 +- .../Array/Accelerate/Trafo/AD/Simplify.hs | 17 +- src/Data/Array/Accelerate/Trafo/AD/Sink.hs | 8 +- .../Array/Accelerate/Trafo/AD/Translate.hs | 11 +- test/tom/Main.hs | 18 +- 10 files changed, 354 insertions(+), 55 deletions(-) diff --git a/src/Data/Array/Accelerate/Trafo/AD.hs b/src/Data/Array/Accelerate/Trafo/AD.hs index 6ff920bae..1663fcca7 100644 --- a/src/Data/Array/Accelerate/Trafo/AD.hs +++ b/src/Data/Array/Accelerate/Trafo/AD.hs @@ -77,6 +77,7 @@ convertAcc (OpenAcc (Reshape shr she a)) = OpenAcc (Reshape shr (convertExp she) convertAcc (OpenAcc (Use rep a)) = OpenAcc (Use rep a) convertAcc (OpenAcc (Fold f e a)) = OpenAcc (Fold (convertFun f) (convertExp <$> e) (convertAcc a)) convertAcc (OpenAcc (Scan dir f e a)) = OpenAcc (Scan dir (convertFun f) (convertExp <$> e) (convertAcc a)) +convertAcc (OpenAcc (Scan' dir f e a)) = OpenAcc (Scan' dir (convertFun f) (convertExp e) (convertAcc a)) convertAcc (OpenAcc (ZipWith ty f a1 a2)) = OpenAcc (ZipWith ty (convertFun f) (convertAcc a1) (convertAcc a2)) convertAcc (OpenAcc (Permute f a1 fi a2)) = OpenAcc (Permute (convertFun f) (convertAcc a1) (convertFun fi) (convertAcc a2)) convertAcc (OpenAcc (Backpermute rep e f a)) = OpenAcc (Backpermute rep (convertExp e) (convertFun f) (convertAcc a)) diff --git a/src/Data/Array/Accelerate/Trafo/AD/ADAcc.hs b/src/Data/Array/Accelerate/Trafo/AD/ADAcc.hs index f2c922042..4d2f8f8e4 100644 --- a/src/Data/Array/Accelerate/Trafo/AD/ADAcc.hs +++ b/src/Data/Array/Accelerate/Trafo/AD/ADAcc.hs @@ -42,13 +42,14 @@ import Data.Array.Accelerate.Representation.Type import Data.Array.Accelerate.Trafo.AD.Acc import Data.Array.Accelerate.Trafo.AD.Additive import Data.Array.Accelerate.Trafo.AD.ADExp (splitLambdaAD, labeliseExp, labeliseFun, inlineAvarLabels') +import qualified Data.Array.Accelerate.Trafo.AD.ADExp as ADExp import Data.Array.Accelerate.Trafo.AD.Algorithms import Data.Array.Accelerate.Trafo.AD.Common import Data.Array.Accelerate.Trafo.AD.Debug import Data.Array.Accelerate.Trafo.AD.Exp import Data.Array.Accelerate.Trafo.AD.Pretty import Data.Array.Accelerate.Trafo.AD.Sink -import Data.Array.Accelerate.Trafo.Substitution (rebuildLHS, weakenVars) +import Data.Array.Accelerate.Trafo.Substitution (rebuildLHS, weakenVars, weaken) import Data.Array.Accelerate.Trafo.Var (declareVars, DeclareVars(..)) @@ -100,6 +101,8 @@ generaliseArgs (Acond ty e1 e2 e3) = Acond ty e1 (generaliseArgs e2) (generalise generaliseArgs (Map ty f e) = Map ty f (generaliseArgs e) generaliseArgs (ZipWith ty f e1 e2) = ZipWith ty f (generaliseArgs e1) (generaliseArgs e2) generaliseArgs (Fold ty f me0 a) = Fold ty f me0 (generaliseArgs a) +generaliseArgs (Scan ty dir f me0 a) = Scan ty dir f me0 (generaliseArgs a) +generaliseArgs (Scan' ty dir f e0 a) = Scan' ty dir f e0 (generaliseArgs a) generaliseArgs (Backpermute ty dim f e) = Backpermute ty dim f (generaliseArgs e) generaliseArgs (Replicate ty sht she e) = Replicate ty sht she (generaliseArgs e) generaliseArgs (Slice ty sht e she) = Slice ty sht (generaliseArgs e) she @@ -163,6 +166,9 @@ reverseADA paramlhs expr ex2 = varsToArgs vars2 in Apair (TupRpair (atypeOf ex1) (atypeOf ex2)) ex1 ex2 + -- TODO: produceGradient should take the ArrayVars value from BEFORE + -- varsToArgs, not after. That eliminates the error case if the argument is + -- not Nil/Pair/Arg. produceGradient :: DMap (Idx args) (ADLabelT Int) -> AContext Int aenv -> OpenAcc () unused1 unused2 args t @@ -194,7 +200,9 @@ realiseArgs = \expr lhs -> go A.weakenId (A.weakenWithLHS lhs) expr Acond ty e1 e2 e3 -> Acond ty (sinkExpAenv varWeaken e1) (go argWeaken varWeaken e2) (go argWeaken varWeaken e3) Map ty f e -> Map ty (sinkFunAenv varWeaken <$> f) (go argWeaken varWeaken e) ZipWith ty f e1 e2 -> ZipWith ty (sinkFunAenv varWeaken <$> f) (go argWeaken varWeaken e1) (go argWeaken varWeaken e2) - Fold ty f me0 e -> Fold ty (sinkFunAenv varWeaken <$> f) (sinkExpAenv varWeaken <$> me0) (go argWeaken varWeaken e) + Fold ty f me0 e -> Fold ty (sinkFunAenv varWeaken f) (sinkExpAenv varWeaken <$> me0) (go argWeaken varWeaken e) + Scan ty dir f me0 e -> Scan ty dir (sinkFunAenv varWeaken f) (sinkExpAenv varWeaken <$> me0) (go argWeaken varWeaken e) + Scan' ty dir f e0 e -> Scan' ty dir (sinkFunAenv varWeaken f) (sinkExpAenv varWeaken e0) (go argWeaken varWeaken e) Backpermute ty dim f e -> Backpermute ty (sinkExpAenv varWeaken dim) (sinkFunAenv varWeaken f) (go argWeaken varWeaken e) Permute ty cf def pf e -> Permute ty (sinkFunAenv varWeaken cf) (go argWeaken varWeaken def) (sinkFunAenv varWeaken pf) (go argWeaken varWeaken e) Sum ty e -> Sum ty (go argWeaken varWeaken e) @@ -269,16 +277,18 @@ explode' labelenv = \case argmp = DMap.unionsWithKey (error "explode: Overlapping arg's") [argmp1, argmp2] return (lab, mp, argmp) ZipWith _ (Left _) _ _ -> error "explode: Unexpected ZipWith SplitLambdaAD" - Fold ty@(ArrayR sht _) (Right e) me0 a -> do - e' <- splitLambdaAD labelenv (genSingleId . ArrayR sht) (generaliseLabFun e) - let me0' = snd . labeliseExp labelenv . generaliseLabA . generaliseLab <$> me0 + Fold ty e me0 a -> do + -- TODO: This does NOT split the lambda in Fold. This is because + -- currently, we do recompute-all for the Fold lambda, not store-all, + -- since where do we even store the temporaries? + let e' = snd . labeliseFun labelenv . generaliseLabFunA . generaliseLabFun $ e + me0' = snd . labeliseExp labelenv . generaliseLabA . generaliseLab <$> me0 (lab1, mp1, argmp1) <- explode' labelenv a lab <- genId (TupRsingle ty) - let pruned = Fold ty (Left e') me0' (Alabel lab1) + let pruned = Fold ty e' me0' (Alabel lab1) let itemmp = DMap.singleton lab pruned mp = DMap.unionWithKey (error "explode: Overlapping id's") mp1 itemmp return (lab, mp, argmp1) - Fold _ (Left _) _ _ -> error "explode: Unexpected Fold SplitLambdaAD" Backpermute ty dim f a -> do let f' = snd . labeliseFun labelenv . generaliseLabFunA . generaliseLabFun $ f dim' = snd . labeliseExp labelenv . generaliseLabA . generaliseLab $ dim @@ -344,6 +354,8 @@ explode' labelenv = \case Alabel _ -> error "explode: Unexpected Alabel" Reduce _ _ _ _ -> error "explode: Unexpected Reduce, should only be created in dual" Permute _ _ _ _ _ -> error "explode: Unexpected Permute, can't do AD on Permute yet" + Scan _ _ _ _ _ -> error "explode: Unexpected Scan, can't do AD on Scan yet" + Scan' _ _ _ _ _ -> error "explode: Unexpected Scan', can't do AD on Scan' yet" where lpushLHS_Get :: A.ALeftHandSide t aenv aenv' -> TupR (ADLabel Int) t -> ALabVal Int aenv -> Acc lab Int args t -> (ALabVal Int aenv', DMap (ADLabelT Int) (Acc lab Int args)) lpushLHS_Get lhs labs labelenv' rhs = case (lhs, labs) of @@ -484,7 +496,7 @@ primal' nodemap (AnyLabel lbl : restlabels) (Context labelenv bindmap) cont (Alet (A.LeftHandSideSingle pairArrType) (Map pairArrType (Right (fmapAlabFun (fmapLabel P) (lambdaPrimal lambdaVars))) (Avar (A.Var argtypeS argidx))) - (smartPair + (smartPairA (Map restype (Right (expFstLam pairEltType)) (Avar (A.Var pairArrType ZeroIdx))) (Map tmpArrType (Right (expSndLam pairEltType)) (Avar (A.Var pairArrType ZeroIdx))))) <$> primal' nodemap restlabels @@ -509,7 +521,7 @@ primal' nodemap (AnyLabel lbl : restlabels) (Context labelenv bindmap) cont (ZipWith pairArrType (Right (fmapAlabFun (fmapLabel P) (lambdaPrimal lambdaVars))) (Avar (A.Var (labelType arglab1S) argidx1)) (Avar (A.Var (labelType arglab2S) argidx2))) - (smartPair + (smartPairA (Map restype (Right (expFstLam pairEltType)) (Avar (A.Var pairArrType ZeroIdx))) (Map tmpArrType (Right (expSndLam pairEltType)) (Avar (A.Var pairArrType ZeroIdx))))) <$> primal' nodemap restlabels @@ -519,13 +531,13 @@ primal' nodemap (AnyLabel lbl : restlabels) (Context labelenv bindmap) cont _ -> error "primal: ZipWith arguments did not compute arguments" - Fold restype (Left lambda) e0expr (Alabel arglab) -> + Fold restype combfun e0expr (Alabel arglab) -> let TupRsingle arglabS@(DLabel argtype _) = bindmap `dmapFind` fmapLabel P arglab in case alabValFind labelenv arglabS of Just argidx -> do lab <- genSingleId restype Alet (A.LeftHandSideSingle restype) - (Fold restype (Left (fmapAlabSplitLambdaAD (fmapLabel P) lambda)) + (Fold restype (resolveAlabsFun (Context labelenv bindmap) combfun) (resolveAlabs (Context labelenv bindmap) <$> e0expr) (Avar (A.Var argtype argidx))) <$> primal' nodemap restlabels @@ -562,7 +574,7 @@ primal' nodemap (AnyLabel lbl : restlabels) (Context labelenv bindmap) cont (Alet (A.LeftHandSideSingle pairArrType) (Generate pairArrType (resolveAlabs (Context labelenv bindmap) shexp) (Right (fmapAlabFun (fmapLabel P) (lambdaPrimal lambdaVars)))) - (smartPair + (smartPairA (Map restype (Right (expFstLam pairEltType)) (Avar (A.Var pairArrType ZeroIdx))) (Map tmpArrType (Right (expSndLam pairEltType)) (Avar (A.Var pairArrType ZeroIdx))))) <$> primal' nodemap restlabels @@ -805,6 +817,112 @@ dual' nodemap (AnyLabel lbl : restlabels) (Context labelenv bindmap) contribmap (DMap.insert (fmapLabel D lbl) labs bindmap)) contribmap' cont + -- TODO: This does not contribute any derivative to the initial expression, since we don't support array indexing yet. + Fold restype@(ArrayR resshape _) origf@(Lam lambdalhs (Body lambdabody)) (Just initexp) (Alabel arglab) + | TupRsingle (SingleScalarType (NumSingleType elttypeN@(FloatingNumType TypeFloat))) <- etypeOf lambdabody + , ReplicateOneMore onemoreSlix onemoreExpf <- replicateOneMore resshape -> do + let argtype = labelType arglab + TupRsingle argtypeS = argtype + adjoint = collectAdjoint contribmap lbl (Context labelenv bindmap) + templab <- genSingleId argtypeS + let contribmap' = updateContribmap lbl + [Contribution arglab TLNil (TupRsingle templab :@ TLNil) $ + \(TupRsingle adjvar) _ (TupRsingle tempvar :@ TLNil) _ -> + -- zipWith (*) (replicate (shape a) adjoint) (usual_derivative) + smartZipWith (timesLam elttypeN) + (Replicate argtypeS onemoreSlix (onemoreExpf (Shape (Left tempvar))) (Avar adjvar)) + (Avar tempvar)] + contribmap + -- technically don't need the tuple machinery here, but for consistency + (Exists lhs, labs) <- genSingleIds (TupRsingle restype) + -- Compute the derivative here once, so that it only needs to be + -- combined with the local adjoint on every use, instead of having to + -- recompute this whole thing. + let labelenv' = lpushLabTup labelenv lhs labs + case alabValFinds labelenv' (bindmap `dmapFind` fmapLabel P arglab) of + Just (TupRsingle argvar) -> + Alet lhs adjoint + . Alet (A.LeftHandSideSingle (labelType templab)) + (case ADExp.reverseAD lambdalhs (resolveAlabs (Context labelenv' bindmap) lambdabody) of + ADExp.ReverseADResE lambdalhs' dualbody -> + -- let sc = init (scanl f x0 a) + -- in zipWith (*) (zipWith D₂f sc a) + -- (tail (scanr (*) 1 (zipWith D₁f sc a))) + let d1f = Lam lambdalhs' (Body (Get (etypeOf lambdabody) (TILeft TIHere) dualbody)) + d2f = Lam lambdalhs' (Body (Get (etypeOf lambdabody) (TIRight TIHere) dualbody)) + weaken1 = A.weakenSucc' A.weakenId + argvar' = weaken weaken1 argvar + (d1f', d2f') = (sinkFunAenv weaken1 d1f, sinkFunAenv weaken1 d2f) + initScan ty dir f e0 a = -- init (scanl) / tail (scanr) + let scan'type = let ArrayR (ShapeRsnoc shtype) elttype = ty + in TupRpair (TupRsingle ty) (TupRsingle (ArrayR shtype elttype)) + in Aget (TupRsingle ty) (TILeft TIHere) (Scan' scan'type dir f e0 a) + in Alet (A.LeftHandSideSingle argtypeS) + (initScan argtypeS A.LeftToRight + (resolveAlabsFun (Context labelenv' bindmap) origf) + (resolveAlabs (Context labelenv' bindmap) initexp) + (Avar argvar)) + (smartZipWith (timesLam elttypeN) + (smartZipWith d2f' (Avar (A.Var argtypeS ZeroIdx)) (Avar argvar')) + (initScan argtypeS A.RightToLeft (timesLam elttypeN) + (zeroForType' 1 elttypeN) + (smartZipWith d1f' (Avar (A.Var argtypeS ZeroIdx)) (Avar argvar'))))) + <$> dual' nodemap restlabels (Context (LPush labelenv' templab) + (DMap.insert (fmapLabel D lbl) labs bindmap)) + contribmap' cont + _ -> error $ "dual Fold: argument primal was not computed" + + Fold restype@(ArrayR resshape _) origf@(Lam lambdalhs (Body lambdabody)) Nothing (Alabel arglab) + | TupRsingle (SingleScalarType (NumSingleType elttypeN@(FloatingNumType TypeFloat))) <- etypeOf lambdabody + , ReplicateOneMore onemoreSlix onemoreExpf <- replicateOneMore resshape -> do + let argtype = labelType arglab + TupRsingle argtypeS = argtype + adjoint = collectAdjoint contribmap lbl (Context labelenv bindmap) + templab <- genSingleId argtypeS + let contribmap' = updateContribmap lbl + [Contribution arglab TLNil (TupRsingle templab :@ TLNil) $ + \(TupRsingle adjvar) _ (TupRsingle tempvar :@ TLNil) _ -> + -- zipWith (*) (replicate (shape a) adjoint) (usual_derivative) + smartZipWith (timesLam elttypeN) + (Replicate argtypeS onemoreSlix (onemoreExpf (Shape (Left tempvar))) (Avar adjvar)) + (Avar tempvar)] + contribmap + -- technically don't need the tuple machinery here, but for consistency + (Exists lhs, labs) <- genSingleIds (TupRsingle restype) + -- Compute the derivative here once, so that it only needs to be + -- combined with the local adjoint on every use, instead of having to + -- recompute this whole thing. + let labelenv' = lpushLabTup labelenv lhs labs + case alabValFinds labelenv' (bindmap `dmapFind` fmapLabel P arglab) of + Just (TupRsingle argvar) -> + Alet lhs adjoint + . Alet (A.LeftHandSideSingle (labelType templab)) + (case ADExp.reverseAD lambdalhs (resolveAlabs (Context labelenv' bindmap) lambdabody) of + ADExp.ReverseADResE lambdalhs' dualbody -> + -- let sc = init (scanl1 f a) + -- in zipWith (*) ([1] ++ zipWith D₂f sc (tail l)) + -- (scanr (*) 1 (zipWith D₁f sc (tail l))) + let d1f = Lam lambdalhs' (Body (Get (etypeOf lambdabody) (TILeft TIHere) dualbody)) + d2f = Lam lambdalhs' (Body (Get (etypeOf lambdabody) (TIRight TIHere) dualbody)) + weaken1 = A.weakenSucc' A.weakenId + argvar' = weaken weaken1 argvar + (d1f', d2f') = (sinkFunAenv weaken1 d1f, sinkFunAenv weaken1 d2f) + in Alet (A.LeftHandSideSingle argtypeS) + (smartInit (Scan argtypeS A.LeftToRight + (resolveAlabsFun (Context labelenv' bindmap) origf) + Nothing + (Avar argvar))) + (smartZipWith (timesLam elttypeN) + (smartCons (zeroForType' 1 elttypeN) + (smartZipWith d2f' (Avar (A.Var argtypeS ZeroIdx)) (smartTail (Avar argvar')))) + (Scan argtypeS A.RightToLeft (timesLam elttypeN) + (Just (zeroForType' 1 elttypeN)) + (smartZipWith d1f' (Avar (A.Var argtypeS ZeroIdx)) (smartTail (Avar argvar')))))) + <$> dual' nodemap restlabels (Context (LPush labelenv' templab) + (DMap.insert (fmapLabel D lbl) labs bindmap)) + contribmap' cont + _ -> error $ "dual Fold: argument primal was not computed" + -- TODO: Since we don't support array indexing yet, Generate nodes have -- no array arguments they can contribute anything to. Thus, we ignore -- them in the dual pass. @@ -924,6 +1042,77 @@ dual' nodemap (AnyLabel lbl : restlabels) (Context labelenv bindmap) contribmap contribmap' cont expr -> trace ("\n!! " ++ show expr) undefined + where + smartZipWith :: Fun aenv lab alab ((e1, e2) -> e) + -> OpenAcc aenv lab alab args (Array sh e1) + -> OpenAcc aenv lab alab args (Array sh e2) + -> OpenAcc aenv lab alab args (Array sh e) + smartZipWith f@(Lam _ (Body body)) a1 a2 = + let TupRsingle (ArrayR shtype _) = atypeOf a1 + in ZipWith (ArrayR shtype (etypeOf body)) (Right f) a1 a2 + smartZipWith _ _ _ = error "impossible GADTs" + + smartInnerPermute :: (forall env aenv'. OpenExp env aenv' lab alab () Int + -> OpenExp env aenv' lab alab () Int) -- ^ new inner dimension size + -> (forall env aenv'. OpenExp env aenv' lab alab () Int + -> OpenExp env aenv' lab alab () Int) -- ^ inner index transformer + -> OpenAcc aenv lab alab args (Array (sh, Int) t) + -> OpenAcc aenv lab alab args (Array (sh, Int) t) + smartInnerPermute sizeExpr indexExpr a + | TupRsingle ty@(ArrayR shtype _) <- atypeOf a + , TupRpair tailsht _ <- shapeType shtype + , LetBoundExpE shlhs shvars <- elhsCopy tailsht = + Alet (A.LeftHandSideSingle ty) a + (Backpermute ty + (Let (A.LeftHandSidePair shlhs (A.LeftHandSideSingle scalarType)) + (Shape (Left (A.Var ty ZeroIdx))) + (smartPair + (evars (weakenVars (A.weakenSucc A.weakenId) shvars)) + (sizeExpr (Var (A.Var scalarType ZeroIdx))))) + (Lam (A.LeftHandSidePair shlhs (A.LeftHandSideSingle scalarType)) + (Body (smartPair + (evars (weakenVars (A.weakenSucc A.weakenId) shvars)) + (indexExpr (Var (A.Var scalarType ZeroIdx)))))) + (Avar (A.Var ty ZeroIdx))) + smartInnerPermute _ _ _ = error "impossible GADTs" + + smartTail :: OpenAcc aenv lab alab args (Array (sh, Int) t) -> OpenAcc aenv lab alab args (Array (sh, Int) t) + smartTail = smartInnerPermute (\sz -> smartSub numType sz (Const scalarType 1)) + (\idx -> smartAdd numType idx (Const scalarType 1)) + + smartInit :: OpenAcc aenv lab alab args (Array (sh, Int) t) -> OpenAcc aenv lab alab args (Array (sh, Int) t) + smartInit = smartInnerPermute (\sz -> smartSub numType sz (Const scalarType 1)) + (\idx -> idx) + + smartCons :: (forall env aenv'. OpenExp env aenv' lab alab () t) + -> OpenAcc aenv lab alab args (Array (sh, Int) t) + -> OpenAcc aenv lab alab args (Array (sh, Int) t) + smartCons prefix a + | TupRsingle ty@(ArrayR shtype elttype) <- atypeOf a + , TupRpair tailsht _ <- shapeType shtype + , LetBoundExpE shlhs shvars <- elhsCopy tailsht = + Alet (A.LeftHandSideSingle ty) a + (Generate ty + (Let (A.LeftHandSidePair shlhs (A.LeftHandSideSingle scalarType)) + (Shape (Left (A.Var ty ZeroIdx))) + (smartPair + (evars (weakenVars (A.weakenSucc A.weakenId) shvars)) + (smartAdd numType (Var (A.Var scalarType ZeroIdx)) (Const scalarType 1)))) + (Right (Lam (A.LeftHandSidePair shlhs (A.LeftHandSideSingle scalarType)) + (Body (Cond elttype + (smartGt singleType (Var (A.Var scalarType ZeroIdx)) (Const scalarType 0)) + (Index (Left (A.Var ty ZeroIdx)) + (smartPair + (evars (weakenVars (A.weakenSucc A.weakenId) shvars)) + (smartSub numType (Var (A.Var scalarType ZeroIdx)) (Const scalarType 1)))) + prefix))))) + smartCons _ _ = error "impossible GADTs" + + timesLam :: NumType t -> Fun aenv lab alab ((t, t) -> t) + timesLam ty = + let sty = SingleScalarType (NumSingleType ty) + in Lam (A.LeftHandSidePair (A.LeftHandSideSingle sty) (A.LeftHandSideSingle sty)) + (Body (smartMul ty (Var (A.Var sty (SuccIdx ZeroIdx))) (Var (A.Var sty ZeroIdx)))) -- Utility functions -- ----------------- @@ -1077,6 +1266,36 @@ reduceSpecFromReplicate SliceNil = RSpecNil reduceSpecFromReplicate (SliceAll slix) = RSpecKeep (reduceSpecFromReplicate slix) reduceSpecFromReplicate (SliceFixed slix) = RSpecReduce (reduceSpecFromReplicate slix) +data ReplicateOneMore sh = + forall slix. + ReplicateOneMore (SliceIndex slix sh ((), Int) (sh, Int)) + (forall env aenv lab alab args. OpenExp env aenv lab alab args (sh, Int) + -> OpenExp env aenv lab alab args slix) + +-- Produces a SliceIndex that can be passed to Replicate, and a function that +-- produces the slix expression parameter to Replicate, given an expression for +-- the desired full shape. +replicateOneMore :: ShapeR sh -> ReplicateOneMore sh +replicateOneMore sh + | SliceCopy slix e <- sliceCopy sh + = ReplicateOneMore (SliceFixed slix) + (\she -> Let (A.LeftHandSidePair (A.LeftHandSideWildcard (shapeType (sliceShapeR slix))) + (A.LeftHandSideSingle scalarType)) + she + (smartPair e (Var (A.Var scalarType ZeroIdx)))) + +data SliceCopy sh = + forall slix. + SliceCopy (SliceIndex slix sh () sh) + (forall env aenv lab alab args. OpenExp env aenv lab alab args slix) + +sliceCopy :: ShapeR sh -> SliceCopy sh +sliceCopy ShapeRz = SliceCopy SliceNil Nil +sliceCopy (ShapeRsnoc sh) + | SliceCopy slix e <- sliceCopy sh + = SliceCopy (SliceAll slix) (smartPair e Nil) + + -- The dual of Slice is a Generate that picks the adjoint for the entries -- sliced, and zero for the entries cut away. This is the lambda for that -- Generate. @@ -1106,8 +1325,8 @@ sliceDualLambda slix adjvar@(A.Var (ArrayR _ eltty) _) slexpr genCond (SliceAll slix') (TupRpair idxvs _) (TupRpair slvs _) = genCond slix' idxvs slvs genCond (SliceFixed slix') (TupRpair idxvs (TupRsingle idxv)) (TupRpair slvs (TupRsingle slv)) = PrimApp (TupRsingle scalarType) A.PrimLAnd - (smartPairE (PrimApp (TupRsingle scalarType) (A.PrimEq singleType) - (smartPairE (Var idxv) (Var slv))) + (smartPair (PrimApp (TupRsingle scalarType) (A.PrimEq singleType) + (smartPair (Var idxv) (Var slv))) (genCond slix' idxvs slvs)) genCond _ _ _ = error "impossible GADTs" @@ -1126,7 +1345,9 @@ accLabelParents = \case Map _ lam e -> fromLabel e ++ lamLabels lam ZipWith _ lam e1 e2 -> fromLabel e1 ++ fromLabel e2 ++ lamLabels lam Generate _ e lam -> expLabels e ++ lamLabels lam - Fold _ lam me0 e -> lamLabels lam ++ maybe [] expLabels me0 ++ fromLabel e + Fold _ f me0 e -> funLabels f ++ maybe [] expLabels me0 ++ fromLabel e + Scan _ _ f me0 e -> funLabels f ++ maybe [] expLabels me0 ++ fromLabel e + Scan' _ _ f e0 e -> funLabels f ++ expLabels e0 ++ fromLabel e Sum _ e -> fromLabel e Replicate _ _ she e -> expLabels she ++ fromLabel e Slice _ _ e she -> fromLabel e ++ expLabels she @@ -1206,8 +1427,5 @@ smartSnd ex = Aget t2 (TIRight TIHere) ex smartSnd _ = error "smartSnd: impossible GADTs" -smartPair :: OpenAcc aenv lab alab args a -> OpenAcc aenv lab alab args b -> OpenAcc aenv lab alab args (a, b) -smartPair a b = Apair (TupRpair (atypeOf a) (atypeOf b)) a b - -smartPairE :: OpenExp env aenv lab alab args a -> OpenExp env aenv lab alab args b -> OpenExp env aenv lab alab args (a, b) -smartPairE a b = Pair (TupRpair (etypeOf a) (etypeOf b)) a b +smartPairA :: OpenAcc aenv lab alab args a -> OpenAcc aenv lab alab args b -> OpenAcc aenv lab alab args (a, b) +smartPairA a b = Apair (TupRpair (atypeOf a) (atypeOf b)) a b diff --git a/src/Data/Array/Accelerate/Trafo/AD/ADExp.hs b/src/Data/Array/Accelerate/Trafo/AD/ADExp.hs index 815ea6378..78f53988c 100644 --- a/src/Data/Array/Accelerate/Trafo/AD/ADExp.hs +++ b/src/Data/Array/Accelerate/Trafo/AD/ADExp.hs @@ -157,6 +157,9 @@ varsToArgs (TupRpair vars1 vars2) = ex2 = varsToArgs vars2 in Pair (TupRpair (etypeOf ex1) (etypeOf ex2)) ex1 ex2 +-- TODO: produceGradient should take the ExpVars value from BEFORE varsToArgs, +-- not after. That eliminates the error case if the argument is not +-- Nil/Pair/Arg. produceGradient :: DMap (Idx args) (EDLabelT Int) -> EContext Int env -> OpenExp () aenv unused unused2 args t @@ -961,27 +964,6 @@ dual' nodemap lbl (Context labelenv bindmap) contribmap = (Let lhs adjoint) expr -> trace ("\n!! " ++ show expr) undefined - where - smartPair :: OpenExp env aenv lab alab args a -> OpenExp env aenv lab alab args b -> OpenExp env aenv lab alab args (a, b) - smartPair a b = Pair (TupRpair (etypeOf a) (etypeOf b)) a b - - smartNeg :: NumType t -> OpenExp env aenv lab alab args t -> OpenExp env aenv lab alab args t - smartNeg ty a = PrimApp (TupRsingle (SingleScalarType (NumSingleType ty))) (A.PrimNeg ty) a - - smartRecip :: FloatingType t -> OpenExp env aenv lab alab args t -> OpenExp env aenv lab alab args t - smartRecip ty a = PrimApp (TupRsingle (SingleScalarType (NumSingleType (FloatingNumType ty)))) (A.PrimRecip ty) a - - smartSub :: NumType t -> OpenExp env aenv lab alab args t -> OpenExp env aenv lab alab args t -> OpenExp env aenv lab alab args t - smartSub ty a b = PrimApp (TupRsingle (SingleScalarType (NumSingleType ty))) (A.PrimSub ty) (smartPair a b) - - smartMul :: NumType t -> OpenExp env aenv lab alab args t -> OpenExp env aenv lab alab args t -> OpenExp env aenv lab alab args t - smartMul ty a b = PrimApp (TupRsingle (SingleScalarType (NumSingleType ty))) (A.PrimMul ty) (smartPair a b) - - smartFDiv :: FloatingType t -> OpenExp env aenv lab alab args t -> OpenExp env aenv lab alab args t -> OpenExp env aenv lab alab args t - smartFDiv ty a b = PrimApp (TupRsingle (SingleScalarType (NumSingleType (FloatingNumType ty)))) (A.PrimFDiv ty) (smartPair a b) - - smartGt :: SingleType t -> OpenExp env aenv lab alab args t -> OpenExp env aenv lab alab args t -> OpenExp env aenv lab alab args A.PrimBool - smartGt ty a b = PrimApp (TupRsingle scalarType) (A.PrimGt ty) (smartPair a b) -- TODO: make a new abstraction after the refactor, possibly inspired by this function, which was the abstraction pre-refactor -- dualStoreAdjoint diff --git a/src/Data/Array/Accelerate/Trafo/AD/Acc.hs b/src/Data/Array/Accelerate/Trafo/AD/Acc.hs index d3ff7f1bd..bd64cfb82 100644 --- a/src/Data/Array/Accelerate/Trafo/AD/Acc.hs +++ b/src/Data/Array/Accelerate/Trafo/AD/Acc.hs @@ -68,7 +68,7 @@ data OpenAcc aenv lab alab args t where -> OpenAcc aenv lab alab args (Array sh t3) Fold :: ArrayR (Array sh e) - -> ExpLambda1 aenv lab alab sh (e, e) e + -> Fun aenv lab alab ((e, e) -> e) -> Maybe (Exp aenv lab alab () e) -> OpenAcc aenv lab alab args (Array (sh, Int) e) -> OpenAcc aenv lab alab args (Array sh e) @@ -77,6 +77,20 @@ data OpenAcc aenv lab alab args t where -> OpenAcc aenv lab alab args (Array (sh, Int) e) -> OpenAcc aenv lab alab args (Array sh e) + Scan :: ArrayR (Array (sh, Int) e) + -> A.Direction + -> Fun aenv lab alab ((e, e) -> e) + -> Maybe (Exp aenv lab alab () e) + -> OpenAcc aenv lab alab args (Array (sh, Int) e) + -> OpenAcc aenv lab alab args (Array (sh, Int) e) + + Scan' :: ArraysR (Array (sh, Int) e, Array sh e) + -> A.Direction + -> Fun aenv lab alab ((e, e) -> e) + -> Exp aenv lab alab () e + -> OpenAcc aenv lab alab args (Array (sh, Int) e) + -> OpenAcc aenv lab alab args (Array (sh, Int) e, Array sh e) + Generate :: ArrayR (Array sh e) -> Exp aenv lab alab () sh -> ExpLambda1 aenv lab alab sh sh e @@ -213,9 +227,21 @@ showsAcc se d (ZipWith _ f e1 e2) = showsAcc se d (Fold _ f me0 e) = showParen (d > 10) $ showString (maybe "fold1 " (const "fold ") me0) . - showsLambda (se { seEnv = [] }) 11 f . showString " " . + showsFun (se { seEnv = [] }) 11 f . showString " " . maybe id (\e0 -> showsExp (se { seEnv = [] }) 11 e0 . showString " ") me0 . showsAcc se 11 e +showsAcc se d (Scan _ dir f me0 e) = + showParen (d > 10) $ + showString ("scan" ++ (case dir of A.LeftToRight -> "l" ; A.RightToLeft -> "r") ++ maybe "1" (const "") me0 ++ " ") . + showsFun (se { seEnv = [] }) 11 f . showString " " . + maybe id (\e0 -> showsExp (se { seEnv = [] }) 11 e0 . showString " ") me0 . + showsAcc se 11 e +showsAcc se d (Scan' _ dir f e0 e) = + showParen (d > 10) $ + showString ("scan" ++ (case dir of A.LeftToRight -> "l" ; A.RightToLeft -> "r") ++ "' ") . + showsFun (se { seEnv = [] }) 11 f . showString " " . + showsExp (se { seEnv = [] }) 11 e0 . showString " " . + showsAcc se 11 e showsAcc se d (Backpermute _ dim f e) = showParen (d > 10) $ showString "backpermute " . @@ -304,6 +330,8 @@ atypeOf (Map ty _ _) = TupRsingle ty atypeOf (ZipWith ty _ _ _) = TupRsingle ty atypeOf (Generate ty _ _) = TupRsingle ty atypeOf (Fold ty _ _ _) = TupRsingle ty +atypeOf (Scan ty _ _ _ _) = TupRsingle ty +atypeOf (Scan' ty _ _ _ _) = ty atypeOf (Backpermute ty _ _ _) = TupRsingle ty atypeOf (Permute ty _ _ _ _) = TupRsingle ty atypeOf (Replicate ty _ _ _) = TupRsingle ty diff --git a/src/Data/Array/Accelerate/Trafo/AD/Exp.hs b/src/Data/Array/Accelerate/Trafo/AD/Exp.hs index e9352c0cb..fbedab6c5 100644 --- a/src/Data/Array/Accelerate/Trafo/AD/Exp.hs +++ b/src/Data/Array/Accelerate/Trafo/AD/Exp.hs @@ -470,3 +470,28 @@ mkBool b Const scalarType tag | otherwise = error $ "Bool does not have a " ++ constrName ++ " constructor?" where constrName = if b then "True" else "False" + + +smartPair :: OpenExp env aenv lab alab args a -> OpenExp env aenv lab alab args b -> OpenExp env aenv lab alab args (a, b) +smartPair a b = Pair (TupRpair (etypeOf a) (etypeOf b)) a b + +smartNeg :: NumType t -> OpenExp env aenv lab alab args t -> OpenExp env aenv lab alab args t +smartNeg ty a = PrimApp (TupRsingle (SingleScalarType (NumSingleType ty))) (A.PrimNeg ty) a + +smartRecip :: FloatingType t -> OpenExp env aenv lab alab args t -> OpenExp env aenv lab alab args t +smartRecip ty a = PrimApp (TupRsingle (SingleScalarType (NumSingleType (FloatingNumType ty)))) (A.PrimRecip ty) a + +smartAdd :: NumType t -> OpenExp env aenv lab alab args t -> OpenExp env aenv lab alab args t -> OpenExp env aenv lab alab args t +smartAdd ty a b = PrimApp (TupRsingle (SingleScalarType (NumSingleType ty))) (A.PrimAdd ty) (smartPair a b) + +smartSub :: NumType t -> OpenExp env aenv lab alab args t -> OpenExp env aenv lab alab args t -> OpenExp env aenv lab alab args t +smartSub ty a b = PrimApp (TupRsingle (SingleScalarType (NumSingleType ty))) (A.PrimSub ty) (smartPair a b) + +smartMul :: NumType t -> OpenExp env aenv lab alab args t -> OpenExp env aenv lab alab args t -> OpenExp env aenv lab alab args t +smartMul ty a b = PrimApp (TupRsingle (SingleScalarType (NumSingleType ty))) (A.PrimMul ty) (smartPair a b) + +smartFDiv :: FloatingType t -> OpenExp env aenv lab alab args t -> OpenExp env aenv lab alab args t -> OpenExp env aenv lab alab args t +smartFDiv ty a b = PrimApp (TupRsingle (SingleScalarType (NumSingleType (FloatingNumType ty)))) (A.PrimFDiv ty) (smartPair a b) + +smartGt :: SingleType t -> OpenExp env aenv lab alab args t -> OpenExp env aenv lab alab args t -> OpenExp env aenv lab alab args A.PrimBool +smartGt ty a b = PrimApp (TupRsingle scalarType) (A.PrimGt ty) (smartPair a b) diff --git a/src/Data/Array/Accelerate/Trafo/AD/Pretty.hs b/src/Data/Array/Accelerate/Trafo/AD/Pretty.hs index c0d5201dc..e0f7c2ccf 100644 --- a/src/Data/Array/Accelerate/Trafo/AD/Pretty.hs +++ b/src/Data/Array/Accelerate/Trafo/AD/Pretty.hs @@ -1,6 +1,7 @@ {-# LANGUAGE GADTs #-} module Data.Array.Accelerate.Trafo.AD.Pretty (prettyPrint) where +import qualified Data.Array.Accelerate.AST as A import qualified Data.Array.Accelerate.AST.Var as A import Data.Array.Accelerate.Representation.Array import Data.Array.Accelerate.Trafo.AD.Acc @@ -194,9 +195,21 @@ layoutAcc se d (ZipWith _ f e1 e2) = layoutAcc se d (Fold _ f me0 e) = parenthesise (d > 10) $ lprefix (maybe "fold1 " (const "fold ") me0) - (lseq' $ concat [[layoutLambda (se { seEnv = [] }) 11 f] + (lseq' $ concat [[layoutFun (se { seEnv = [] }) 11 f] ,maybe [] (\e0 -> [layoutExp (se { seEnv = [] }) 11 e0]) me0 ,[layoutAcc se 11 e]]) +layoutAcc se d (Scan _ dir f me0 e) = + parenthesise (d > 10) $ + lprefix ("scan" ++ (case dir of A.LeftToRight -> "l" ; A.RightToLeft -> "r") ++ maybe "1" (const "") me0 ++ " ") + (lseq' $ concat [[layoutFun (se { seEnv = [] }) 11 f] + ,maybe [] (\e0 -> [layoutExp (se { seEnv = [] }) 11 e0]) me0 + ,[layoutAcc se 11 e]]) +layoutAcc se d (Scan' _ dir f e0 e) = + parenthesise (d > 10) $ + lprefix ("scan" ++ (case dir of A.LeftToRight -> "l" ; A.RightToLeft -> "r") ++ "' ") + (lseq' [layoutFun (se { seEnv = [] }) 11 f + ,layoutExp (se { seEnv = [] }) 11 e0 + ,layoutAcc se 11 e]) layoutAcc se d (Backpermute _ dim f e) = parenthesise (d > 10) $ lprefix "backpermute " diff --git a/src/Data/Array/Accelerate/Trafo/AD/Simplify.hs b/src/Data/Array/Accelerate/Trafo/AD/Simplify.hs index e583f6835..7f203119d 100644 --- a/src/Data/Array/Accelerate/Trafo/AD/Simplify.hs +++ b/src/Data/Array/Accelerate/Trafo/AD/Simplify.hs @@ -78,10 +78,15 @@ goAcc = \case Map ty lam a -> Map ty !$! simplifyLam1 lam !**! goAcc a ZipWith ty lam a1 a2 -> ZipWith ty !$! simplifyLam1 lam !**! goAcc a1 !**! goAcc a2 Generate ty e lam -> Generate ty !$! goExp' e !**! simplifyLam1 lam - Fold ty lam me0 a -> Fold ty !$! simplifyLam1 lam - !**! (case me0 of Just e0 -> Just !$! goExp' e0 - Nothing -> returnS Nothing) - !**! goAcc a + Fold ty f me0 a -> Fold ty !$! simplifyFun f + !**! (case me0 of Just e0 -> Just !$! goExp' e0 + Nothing -> returnS Nothing) + !**! goAcc a + Scan ty dir f me0 a -> Scan ty dir !$! simplifyFun f + !**! (case me0 of Just e0 -> Just !$! goExp' e0 + Nothing -> returnS Nothing) + !**! goAcc a + Scan' ty dir f e0 a -> Scan' ty dir !$! simplifyFun f !**! goExp' e0 !**! goAcc a Sum ty a -> Sum ty !$! goAcc a Replicate ty slt she a -> Replicate ty slt !$! goExp' she !**! goAcc a Slice ty slt a she -> Slice ty slt !$! goAcc a !**! goExp' she @@ -206,7 +211,9 @@ inlineA f = \case Map ty lam a -> Map ty (inlineALam f lam) (inlineA f a) ZipWith ty lam a1 a2 -> ZipWith ty (inlineALam f lam) (inlineA f a1) (inlineA f a2) Generate ty e lam -> Generate ty (inlineAE f e) (inlineALam f lam) - Fold ty lam me0 a -> Fold ty (inlineALam f lam) (inlineAE f <$> me0) (inlineA f a) + Fold ty fun me0 a -> Fold ty (inlineAEF f fun) (inlineAE f <$> me0) (inlineA f a) + Scan ty dir fun me0 a -> Scan ty dir (inlineAEF f fun) (inlineAE f <$> me0) (inlineA f a) + Scan' ty dir fun e0 a -> Scan' ty dir (inlineAEF f fun) (inlineAE f e0) (inlineA f a) Sum ty a -> Sum ty (inlineA f a) Replicate ty slt she a -> Replicate ty slt (inlineAE f she) (inlineA f a) Slice ty slt a she -> Slice ty slt (inlineA f a) (inlineAE f she) diff --git a/src/Data/Array/Accelerate/Trafo/AD/Sink.hs b/src/Data/Array/Accelerate/Trafo/AD/Sink.hs index 79425b6c0..9acc0f149 100644 --- a/src/Data/Array/Accelerate/Trafo/AD/Sink.hs +++ b/src/Data/Array/Accelerate/Trafo/AD/Sink.hs @@ -50,7 +50,9 @@ sinkAcc _ Anil = Anil sinkAcc k (Acond ty c t e) = Acond ty (sinkExpAenv k c) (sinkAcc k t) (sinkAcc k e) sinkAcc k (Map ty f e) = Map ty (sinkFunAenv k <$> f) (sinkAcc k e) sinkAcc k (ZipWith ty f e1 e2) = ZipWith ty (sinkFunAenv k <$> f) (sinkAcc k e1) (sinkAcc k e2) -sinkAcc k (Fold ty f me0 e) = Fold ty (sinkFunAenv k <$> f) (sinkExpAenv k <$> me0) (sinkAcc k e) +sinkAcc k (Fold ty f me0 e) = Fold ty (sinkFunAenv k f) (sinkExpAenv k <$> me0) (sinkAcc k e) +sinkAcc k (Scan ty dir f me0 e) = Scan ty dir (sinkFunAenv k f) (sinkExpAenv k <$> me0) (sinkAcc k e) +sinkAcc k (Scan' ty dir f e0 e) = Scan' ty dir (sinkFunAenv k f) (sinkExpAenv k e0) (sinkAcc k e) sinkAcc k (Backpermute ty dim f e) = Backpermute ty (sinkExpAenv k dim) (sinkFunAenv k f) (sinkAcc k e) sinkAcc k (Permute ty comb def pf e) = Permute ty (sinkFunAenv k comb) (sinkAcc k def) (sinkFunAenv k pf) (sinkAcc k e) sinkAcc k (Sum ty e) = Sum ty (sinkAcc k e) @@ -157,7 +159,9 @@ aCheckClosedInTagval tv expr = case expr of Acond ty c t e -> Acond ty <$> eCheckAClosedInTagval tv c <*> aCheckClosedInTagval tv t <*> aCheckClosedInTagval tv e Map ty f e -> Map ty <$> traverse (efCheckAClosedInTagval tv) f <*> aCheckClosedInTagval tv e ZipWith ty f e1 e2 -> ZipWith ty <$> traverse (efCheckAClosedInTagval tv) f <*> aCheckClosedInTagval tv e1 <*> aCheckClosedInTagval tv e2 - Fold ty f me0 e -> Fold ty <$> traverse (efCheckAClosedInTagval tv) f <*> traverse (eCheckAClosedInTagval tv) me0 <*> aCheckClosedInTagval tv e + Fold ty f me0 e -> Fold ty <$> efCheckAClosedInTagval tv f <*> traverse (eCheckAClosedInTagval tv) me0 <*> aCheckClosedInTagval tv e + Scan ty dir f me0 e -> Scan ty dir <$> efCheckAClosedInTagval tv f <*> traverse (eCheckAClosedInTagval tv) me0 <*> aCheckClosedInTagval tv e + Scan' ty dir f e0 e -> Scan' ty dir <$> efCheckAClosedInTagval tv f <*> eCheckAClosedInTagval tv e0 <*> aCheckClosedInTagval tv e Backpermute ty dim f e -> Backpermute ty <$> eCheckAClosedInTagval tv dim <*> efCheckAClosedInTagval tv f <*> aCheckClosedInTagval tv e Permute ty cf def pf e -> Permute ty <$> efCheckAClosedInTagval tv cf <*> aCheckClosedInTagval tv def <*> efCheckAClosedInTagval tv pf <*> aCheckClosedInTagval tv e Sum ty e -> Sum ty <$> aCheckClosedInTagval tv e diff --git a/src/Data/Array/Accelerate/Trafo/AD/Translate.hs b/src/Data/Array/Accelerate/Trafo/AD/Translate.hs index 9595a529d..9506d2fb5 100644 --- a/src/Data/Array/Accelerate/Trafo/AD/Translate.hs +++ b/src/Data/Array/Accelerate/Trafo/AD/Translate.hs @@ -52,7 +52,11 @@ translateAcc (A.OpenAcc expr) = case expr of | Nothing <- initval -> D.Sum (A.arrayR expr) (translateAcc e) A.Fold f me0 e -> - D.Fold (A.arrayR expr) (Right $ toPairedBinop $ translateFun f) (translateExp <$> me0) (translateAcc e) + D.Fold (A.arrayR expr) (toPairedBinop $ translateFun f) (translateExp <$> me0) (translateAcc e) + A.Scan dir f me0 e -> + D.Scan (A.arrayR expr) dir (toPairedBinop $ translateFun f) (translateExp <$> me0) (translateAcc e) + A.Scan' dir f e0 e -> + D.Scan' (A.arraysR expr) dir (toPairedBinop $ translateFun f) (translateExp e0) (translateAcc e) A.Generate ty she f -> D.Generate ty (translateExp she) (Right $ translateFun f) A.Replicate slt sle e -> @@ -228,7 +232,9 @@ untranslateLHSboundAcc toplhs topexpr D.Acond _ e1 e2 e3 -> A.Acond (untranslateClosedExpA e1 pv) (go e2 pv) (go e3 pv) D.Map (ArrayR _ ty) (Right f) e -> A.Map ty (untranslateClosedFunA f pv) (go e pv) D.ZipWith (ArrayR _ ty) (Right f) e1 e2 -> A.ZipWith ty (untranslateClosedFunA (fromPairedBinop f) pv) (go e1 pv) (go e2 pv) - D.Fold _ (Right f) me0 e -> A.Fold (untranslateClosedFunA (fromPairedBinop f) pv) (untranslateClosedExpA <$> me0 <*> Just pv) (go e pv) + D.Fold _ f me0 e -> A.Fold (untranslateClosedFunA (fromPairedBinop f) pv) (untranslateClosedExpA <$> me0 <*> Just pv) (go e pv) + D.Scan _ dir f me0 e -> A.Scan dir (untranslateClosedFunA (fromPairedBinop f) pv) (untranslateClosedExpA <$> me0 <*> Just pv) (go e pv) + D.Scan' _ dir f e0 e -> A.Scan' dir (untranslateClosedFunA (fromPairedBinop f) pv) (untranslateClosedExpA e0 pv) (go e pv) D.Sum (ArrayR _ (TupRsingle ty@(SingleScalarType (NumSingleType nt)))) e -> A.Fold (A.Lam (A.LeftHandSideSingle ty) (A.Lam (A.LeftHandSideSingle ty) (A.Body (A.PrimApp (A.PrimAdd nt) @@ -261,7 +267,6 @@ untranslateLHSboundAcc toplhs topexpr D.Alabel _ -> internalError "AD.untranslateLHSboundAcc: Unexpected Label in untranslate!" D.Map _ _ _ -> error "Unexpected Map shape in untranslate" D.ZipWith _ _ _ _ -> error "Unexpected ZipWith shape in untranslate" - D.Fold _ _ _ _ -> error "Unexpected Fold shape in untranslate" D.Sum _ _ -> error "Unexpected Sum shape in untranslate" D.Generate _ _ _ -> error "Unexpected Generate shape in untranslate" D.Reduce _ _ _ _ -> error "Unexpected Reduce shape in untranslate" diff --git a/test/tom/Main.hs b/test/tom/Main.hs index af4c4576b..51de7c51d 100644 --- a/test/tom/Main.hs +++ b/test/tom/Main.hs @@ -287,6 +287,21 @@ arrad = do A.gradientA (\arr -> A.sum $ A.zipWith (*) arr (A.generate (A.shape arr) (\(A.I1 i) -> A.cond (i A.< 5) 0 1))) (A.use (A.fromList (Z :. 10) [1 :: Float .. 10])) +adfold :: IO () +adfold = do + -- print $ + -- A.gradientA (\arr -> A.maximum arr) + -- (A.use (A.fromList (Z :. 10) [1 :: Float .. 10])) + let input = A.fromList (Z :. 8) [1 :: Float .. 8] + + AD.aCompareAD (\arr -> A.maximum arr) input + AD.aCompareAD (\arr -> A.product arr) input + + print $ + A.gradientA (\arr -> let p = A.product arr + in A.zipWith (+) (A.map (*2) p) (A.map (*3) p)) + (A.use input) + main :: IO () main = do -- logistic @@ -300,7 +315,8 @@ main = do -- adtuple1 -- adtuple2 -- adtuple3 - arrad + -- arrad + adfold -- neural -- neural2 -- Playground.Neural.main