Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into node-labels
Browse files Browse the repository at this point in the history
  • Loading branch information
croyzor committed Nov 28, 2024
2 parents 695ff5b + cca5a1d commit d28a6e3
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 39 deletions.
72 changes: 36 additions & 36 deletions brat/Brat/Checker/Helpers.hs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ import Util (log2)

import Control.Monad.Freer (req)
import Data.Bifunctor
import Data.Foldable (foldrM)
import Data.Type.Equality (TestEquality(..), (:~:)(..))
import qualified Data.Map as M
import Prelude hiding (last)
Expand Down Expand Up @@ -247,22 +248,19 @@ getThunks :: Modey m
,Overs m UVerb
)
getThunks _ [] = pure ([], [], [])
getThunks Braty row@((src, Right ty):rest) = (eval S0 ty >>= vectorise . (src,)) >>= \case
(src, VFun Braty (ss :->> ts)) -> do
(node, unders, overs, _) <- let ?my = Braty in
anext "Eval" (Eval (end src)) (S0, Some (Zy :* S0)) ss ts
(nodes, unders', overs') <- getThunks Braty rest
pure (node:nodes, unders <> unders', overs <> overs')
-- These shouldn't happen
(_, VFun _ _) -> err $ ExpectedThunk (showMode Braty) (showRow row)
v -> typeErr $ "Force called on non-thunk: " ++ show v
getThunks Kerny row@((src, Right ty):rest) = (eval S0 ty >>= vectorise . (src,)) >>= \case
(src, VFun Kerny (ss :->> ts)) -> do
(node, unders, overs, _) <- let ?my = Kerny in anext "Splice" (Splice (end src)) (S0, Some (Zy :* S0)) ss ts
(nodes, unders', overs') <- getThunks Kerny rest
pure (node:nodes, unders <> unders', overs <> overs')
(_, VFun _ _) -> err $ ExpectedThunk (showMode Kerny) (showRow row)
v -> typeErr $ "Force called on non-(kernel)-thunk: " ++ show v
getThunks Braty ((src, Right ty):rest) = do
ty <- eval S0 ty
(src, (ss :->> ts)) <- vectorise Braty (src, ty)

Check warning on line 253 in brat/Brat/Checker/Helpers.hs

View workflow job for this annotation

GitHub Actions / hlint

Suggestion in getThunks in module Brat.Checker.Helpers: Redundant bracket ▫︎ Found: "(src, (ss :->> ts))" ▫︎ Perhaps: "(src, ss :->> ts)"
(node, unders, overs, _) <- let ?my = Braty in
anext "Eval" (Eval (end src)) (S0, Some (Zy :* S0)) ss ts
(nodes, unders', overs') <- getThunks Braty rest
pure (node:nodes, unders <> unders', overs <> overs')
getThunks Kerny ((src, Right ty):rest) = do
ty <- eval S0 ty
(src, (ss :->> ts)) <- vectorise Kerny (src,ty)

Check warning on line 260 in brat/Brat/Checker/Helpers.hs

View workflow job for this annotation

GitHub Actions / hlint

Suggestion in getThunks in module Brat.Checker.Helpers: Redundant bracket ▫︎ Found: "(src, (ss :->> ts))" ▫︎ Perhaps: "(src, ss :->> ts)"
(node, unders, overs, _) <- let ?my = Kerny in anext "Splice" (Splice (end src)) (S0, Some (Zy :* S0)) ss ts
(nodes, unders', overs') <- getThunks Kerny rest
pure (node:nodes, unders <> unders', overs <> overs')
getThunks Braty ((src, Left (Star args)):rest) = do
(node, unders, overs) <- case bwdStack (B0 <>< args) of
Some (_ :* stk) -> do
Expand All @@ -274,15 +272,15 @@ getThunks Braty ((src, Left (Star args)):rest) = do
getThunks m ro = err $ ExpectedThunk (showMode m) (showRow ro)

-- The type given here should be normalised
vecLayers :: Val Z -> Checking ([(Src, NumVal (VVar Z))] -- The sizes of the vector layers
,Some (Modey :* Flip CTy Z) -- The function type at the end
)
vecLayers (TVec ty (VNum n)) = do
vecLayers :: Modey m -> Val Z -> Checking ([(Src, NumVal (VVar Z))] -- The sizes of the vector layers
,CTy m Z -- The function type at the end
)
vecLayers my (TVec ty (VNum n)) = do
src <- mkStaticNum n
(layers, fun) <- vecLayers ty
pure ((src, n):layers, fun)
vecLayers (VFun my cty) = pure ([], Some (my :* Flip cty))
vecLayers ty = typeErr $ "Expected a function or vector of functions, got " ++ show ty
first ((src, n):) <$> vecLayers my ty
vecLayers Braty (VFun Braty cty) = pure ([], cty)
vecLayers Kerny (VFun Kerny cty) = pure ([], cty)
vecLayers my ty = typeErr $ "Expected a " ++ showMode my ++ "function or vector of functions, got " ++ show ty

mkStaticNum :: NumVal (VVar Z) -> Checking Src
mkStaticNum n@(NumValue c gro) = do
Expand Down Expand Up @@ -330,27 +328,29 @@ mkStaticNum n@(NumValue c gro) = do
wire (oneSrc, TNat, rhs)
pure src

vectorise :: (Src, Val Z) -> Checking (Src, Val Z)
vectorise (src, ty) = do
(layers, Some (my :* Flip cty)) <- vecLayers ty
modily my $ mkMapFuns (src, VFun my cty) layers
vectorise :: forall m. Modey m -> (Src, Val Z) -> Checking (Src, CTy m Z)
vectorise my (src, ty) = do
(layers, cty) <- vecLayers my ty
modily my $ foldrM mkMapFun (src, cty) layers
where
mkMapFuns :: (Src, Val Z) -- The input to the mapfun
-> [(Src, NumVal (VVar Z))] -- Remaining layers
-> Checking (Src, Val Z)
mkMapFuns over [] = pure over
mkMapFuns (valSrc, ty) ((lenSrc, len):layers) = do
(valSrc, ty@(VFun my cty)) <- mkMapFuns (valSrc, ty) layers
mkMapFun :: (Src, NumVal (VVar Z)) -- Layer to apply
-> (Src, CTy m Z) -- The input to this level of mapfun
-> Checking (Src, CTy m Z)
mkMapFun (lenSrc, len) (valSrc, cty) = do
let weak1 = changeVar (Thinning (ThDrop ThNull))

Check warning on line 340 in brat/Brat/Checker/Helpers.hs

View workflow job for this annotation

GitHub Actions / build

• The Monomorphism Restriction applies to the binding for ‘weak1’

Check warning on line 340 in brat/Brat/Checker/Helpers.hs

View workflow job for this annotation

GitHub Actions / build

• The Monomorphism Restriction applies to the binding for ‘weak1’
vecFun <- vectorisedFun len my cty
(_, [(lenTgt,_), (valTgt, _)], [(vectorSrc, Right vecTy)], _) <-
(_, [(lenTgt,_), (valTgt, _)], [(vectorSrc, Right (VFun my' cty))], _) <-
next "MapFun" MapFun (S0, Some (Zy :* S0))
(REx ("len", Nat) (RPr ("value", weak1 ty) R0))
(RPr ("vector", weak1 vecFun) R0)
defineTgt lenTgt (VNum len)
wire (lenSrc, kindType Nat, lenTgt)
wire (valSrc, ty, valTgt)
pure (vectorSrc, vecTy)
let vecCTy = case (my,my',cty) of
(Braty,Braty,cty) -> cty
(Kerny,Kerny,cty) -> cty
_ -> error "next returned wrong mode of computation type to that passed in"
pure (vectorSrc, vecCTy)

vectorisedFun :: NumVal (VVar Z) -> Modey m -> CTy m Z -> Checking (Val Z)
vectorisedFun nv my (ss :->> ts) = do
Expand Down
4 changes: 1 addition & 3 deletions brat/test/golden/kernel/kernel_application.brat.golden
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,5 @@ Error in test/golden/kernel/kernel_application.brat on line 16:
rotate = { q => maybeRotate(true) }
^^^^^^^^^^^

Expected function to be a (kernel) thunk, but found:
(thunk :: { (a1 :: Bool) -> (a1 :: { (a1 :: Qubit) -o (a1 :: Qubit) }) })

Type error: Expected a (kernel) function or vector of functions, got { (a1 :: Bool) -> (a1 :: { (a1 :: Qubit) -o (a1 :: Qubit) }) }

0 comments on commit d28a6e3

Please sign in to comment.