diff --git a/brat/Brat/Checker/Helpers.hs b/brat/Brat/Checker/Helpers.hs index 6428d814..b31d2457 100644 --- a/brat/Brat/Checker/Helpers.hs +++ b/brat/Brat/Checker/Helpers.hs @@ -38,6 +38,7 @@ import Util (log2) import Control.Monad.Freer (req) import Data.Bifunctor +import Data.Foldable (foldrM) import Data.Type.Equality (TestEquality(..), (:~:)(..)) import qualified Data.Map as M import Prelude hiding (last) @@ -247,22 +248,19 @@ getThunks :: Modey m ,Overs m UVerb ) getThunks _ [] = pure ([], [], []) -getThunks Braty row@((src, Right ty):rest) = (eval S0 ty >>= vectorise . (src,)) >>= \case - (src, VFun Braty (ss :->> ts)) -> do - (node, unders, overs, _) <- let ?my = Braty in - anext "Eval" (Eval (end src)) (S0, Some (Zy :* S0)) ss ts - (nodes, unders', overs') <- getThunks Braty rest - pure (node:nodes, unders <> unders', overs <> overs') - -- These shouldn't happen - (_, VFun _ _) -> err $ ExpectedThunk (showMode Braty) (showRow row) - v -> typeErr $ "Force called on non-thunk: " ++ show v -getThunks Kerny row@((src, Right ty):rest) = (eval S0 ty >>= vectorise . (src,)) >>= \case - (src, VFun Kerny (ss :->> ts)) -> do - (node, unders, overs, _) <- let ?my = Kerny in anext "Splice" (Splice (end src)) (S0, Some (Zy :* S0)) ss ts - (nodes, unders', overs') <- getThunks Kerny rest - pure (node:nodes, unders <> unders', overs <> overs') - (_, VFun _ _) -> err $ ExpectedThunk (showMode Kerny) (showRow row) - v -> typeErr $ "Force called on non-(kernel)-thunk: " ++ show v +getThunks Braty ((src, Right ty):rest) = do + ty <- eval S0 ty + (src, (ss :->> ts)) <- vectorise Braty (src, ty) + (node, unders, overs, _) <- let ?my = Braty in + anext "Eval" (Eval (end src)) (S0, Some (Zy :* S0)) ss ts + (nodes, unders', overs') <- getThunks Braty rest + pure (node:nodes, unders <> unders', overs <> overs') +getThunks Kerny ((src, Right ty):rest) = do + ty <- eval S0 ty + (src, (ss :->> ts)) <- vectorise Kerny (src,ty) + (node, unders, overs, _) <- let ?my = Kerny in anext "Splice" (Splice (end src)) (S0, Some (Zy :* S0)) ss ts + (nodes, unders', overs') <- getThunks Kerny rest + pure (node:nodes, unders <> unders', overs <> overs') getThunks Braty ((src, Left (Star args)):rest) = do (node, unders, overs) <- case bwdStack (B0 <>< args) of Some (_ :* stk) -> do @@ -274,15 +272,15 @@ getThunks Braty ((src, Left (Star args)):rest) = do getThunks m ro = err $ ExpectedThunk (showMode m) (showRow ro) -- The type given here should be normalised -vecLayers :: Val Z -> Checking ([(Src, NumVal (VVar Z))] -- The sizes of the vector layers - ,Some (Modey :* Flip CTy Z) -- The function type at the end - ) -vecLayers (TVec ty (VNum n)) = do +vecLayers :: Modey m -> Val Z -> Checking ([(Src, NumVal (VVar Z))] -- The sizes of the vector layers + ,CTy m Z -- The function type at the end + ) +vecLayers my (TVec ty (VNum n)) = do src <- mkStaticNum n - (layers, fun) <- vecLayers ty - pure ((src, n):layers, fun) -vecLayers (VFun my cty) = pure ([], Some (my :* Flip cty)) -vecLayers ty = typeErr $ "Expected a function or vector of functions, got " ++ show ty + first ((src, n):) <$> vecLayers my ty +vecLayers Braty (VFun Braty cty) = pure ([], cty) +vecLayers Kerny (VFun Kerny cty) = pure ([], cty) +vecLayers my ty = typeErr $ "Expected a " ++ showMode my ++ "function or vector of functions, got " ++ show ty mkStaticNum :: NumVal (VVar Z) -> Checking Src mkStaticNum n@(NumValue c gro) = do @@ -330,27 +328,29 @@ mkStaticNum n@(NumValue c gro) = do wire (oneSrc, TNat, rhs) pure src -vectorise :: (Src, Val Z) -> Checking (Src, Val Z) -vectorise (src, ty) = do - (layers, Some (my :* Flip cty)) <- vecLayers ty - modily my $ mkMapFuns (src, VFun my cty) layers +vectorise :: forall m. Modey m -> (Src, Val Z) -> Checking (Src, CTy m Z) +vectorise my (src, ty) = do + (layers, cty) <- vecLayers my ty + modily my $ foldrM mkMapFun (src, cty) layers where - mkMapFuns :: (Src, Val Z) -- The input to the mapfun - -> [(Src, NumVal (VVar Z))] -- Remaining layers - -> Checking (Src, Val Z) - mkMapFuns over [] = pure over - mkMapFuns (valSrc, ty) ((lenSrc, len):layers) = do - (valSrc, ty@(VFun my cty)) <- mkMapFuns (valSrc, ty) layers + mkMapFun :: (Src, NumVal (VVar Z)) -- Layer to apply + -> (Src, CTy m Z) -- The input to this level of mapfun + -> Checking (Src, CTy m Z) + mkMapFun (lenSrc, len) (valSrc, cty) = do let weak1 = changeVar (Thinning (ThDrop ThNull)) vecFun <- vectorisedFun len my cty - (_, [(lenTgt,_), (valTgt, _)], [(vectorSrc, Right vecTy)], _) <- + (_, [(lenTgt,_), (valTgt, _)], [(vectorSrc, Right (VFun my' cty))], _) <- next "MapFun" MapFun (S0, Some (Zy :* S0)) (REx ("len", Nat) (RPr ("value", weak1 ty) R0)) (RPr ("vector", weak1 vecFun) R0) defineTgt lenTgt (VNum len) wire (lenSrc, kindType Nat, lenTgt) wire (valSrc, ty, valTgt) - pure (vectorSrc, vecTy) + let vecCTy = case (my,my',cty) of + (Braty,Braty,cty) -> cty + (Kerny,Kerny,cty) -> cty + _ -> error "next returned wrong mode of computation type to that passed in" + pure (vectorSrc, vecCTy) vectorisedFun :: NumVal (VVar Z) -> Modey m -> CTy m Z -> Checking (Val Z) vectorisedFun nv my (ss :->> ts) = do diff --git a/brat/test/golden/kernel/kernel_application.brat.golden b/brat/test/golden/kernel/kernel_application.brat.golden index ed0aa8bf..af09def9 100644 --- a/brat/test/golden/kernel/kernel_application.brat.golden +++ b/brat/test/golden/kernel/kernel_application.brat.golden @@ -2,7 +2,5 @@ Error in test/golden/kernel/kernel_application.brat on line 16: rotate = { q => maybeRotate(true) } ^^^^^^^^^^^ - Expected function to be a (kernel) thunk, but found: - (thunk :: { (a1 :: Bool) -> (a1 :: { (a1 :: Qubit) -o (a1 :: Qubit) }) }) - + Type error: Expected a (kernel) function or vector of functions, got { (a1 :: Bool) -> (a1 :: { (a1 :: Qubit) -o (a1 :: Qubit) }) }