Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[refactor] vectorise: take expected mode, make return type explicit #58

Merged
merged 6 commits into from
Nov 27, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 37 additions & 37 deletions brat/Brat/Checker/Helpers.hs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE AllowAmbiguousTypes, ScopedTypeVariables #-}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Scoped type variables should be on all the time, since it's in GHC2021 in the cabal file


module Brat.Checker.Helpers {-(pullPortsRow, pullPortsSig
,simpleCheck
Expand Down Expand Up @@ -39,6 +39,7 @@

import Control.Monad.Freer (req, Free(Ret))
import Data.Bifunctor
import Data.Foldable (foldrM)
import Data.List (intercalate)
import Data.Type.Equality (TestEquality(..), (:~:)(..))
import qualified Data.Map as M
Expand Down Expand Up @@ -258,22 +259,19 @@
,Overs m UVerb
)
getThunks _ [] = pure ([], [], [])
getThunks Braty row@((src, Right ty):rest) = (eval S0 ty >>= vectorise . (src,)) >>= \case
(src, VFun Braty (ss :->> ts)) -> do
(node, unders, overs, _) <- let ?my = Braty in
anext "" (Eval (end src)) (S0, Some (Zy :* S0)) ss ts
(nodes, unders', overs') <- getThunks Braty rest
pure (node:nodes, unders <> unders', overs <> overs')
-- These shouldn't happen
(_, VFun _ _) -> err $ ExpectedThunk (showMode Braty) (showRow row)
v -> typeErr $ "Force called on non-thunk: " ++ show v
getThunks Kerny row@((src, Right ty):rest) = (eval S0 ty >>= vectorise . (src,)) >>= \case
(src, VFun Kerny (ss :->> ts)) -> do
(node, unders, overs, _) <- let ?my = Kerny in anext "" (Splice (end src)) (S0, Some (Zy :* S0)) ss ts
(nodes, unders', overs') <- getThunks Kerny rest
pure (node:nodes, unders <> unders', overs <> overs')
(_, VFun _ _) -> err $ ExpectedThunk (showMode Kerny) (showRow row)
v -> typeErr $ "Force called on non-(kernel)-thunk: " ++ show v
getThunks Braty ((src, Right ty):rest) = do
ty <- eval S0 ty
(src, (ss :->> ts)) <- vectorise Braty (src, ty)

Check warning on line 264 in brat/Brat/Checker/Helpers.hs

View workflow job for this annotation

GitHub Actions / hlint

Suggestion in getThunks in module Brat.Checker.Helpers: Redundant bracket ▫︎ Found: "(src, (ss :->> ts))" ▫︎ Perhaps: "(src, ss :->> ts)"
(node, unders, overs, _) <- let ?my = Braty in
anext "" (Eval (end src)) (S0, Some (Zy :* S0)) ss ts
(nodes, unders', overs') <- getThunks Braty rest
pure (node:nodes, unders <> unders', overs <> overs')
getThunks Kerny ((src, Right ty):rest) = do
ty <- eval S0 ty
(src, (ss :->> ts)) <- vectorise Kerny (src,ty)

Check warning on line 271 in brat/Brat/Checker/Helpers.hs

View workflow job for this annotation

GitHub Actions / hlint

Suggestion in getThunks in module Brat.Checker.Helpers: Redundant bracket ▫︎ Found: "(src, (ss :->> ts))" ▫︎ Perhaps: "(src, ss :->> ts)"
(node, unders, overs, _) <- let ?my = Kerny in anext "" (Splice (end src)) (S0, Some (Zy :* S0)) ss ts
(nodes, unders', overs') <- getThunks Kerny rest
pure (node:nodes, unders <> unders', overs <> overs')
getThunks Braty ((src, Left (Star args)):rest) = do
(node, unders, overs) <- case bwdStack (B0 <>< args) of
Some (_ :* stk) -> do
Expand All @@ -285,15 +283,15 @@
getThunks m ro = err $ ExpectedThunk (showMode m) (showRow ro)

-- The type given here should be normalised
vecLayers :: Val Z -> Checking ([(Src, NumVal (VVar Z))] -- The sizes of the vector layers
,Some (Modey :* Flip CTy Z) -- The function type at the end
)
vecLayers (TVec ty (VNum n)) = do
vecLayers :: Modey m -> Val Z -> Checking ([(Src, NumVal (VVar Z))] -- The sizes of the vector layers
,CTy m Z -- The function type at the end
)
vecLayers my (TVec ty (VNum n)) = do
src <- mkStaticNum n
(layers, fun) <- vecLayers ty
pure ((src, n):layers, fun)
vecLayers (VFun my cty) = pure ([], Some (my :* Flip cty))
vecLayers ty = typeErr $ "Expected a function or vector of functions, got " ++ show ty
first ((src, n):) <$> vecLayers my ty
vecLayers Braty (VFun Braty cty) = pure ([], cty)
vecLayers Kerny (VFun Kerny cty) = pure ([], cty)
vecLayers my ty = typeErr $ "Expected a " ++ showMode my ++ "function or vector of functions, got " ++ show ty

mkStaticNum :: NumVal (VVar Z) -> Checking Src
mkStaticNum n@(NumValue c gro) = do
Expand Down Expand Up @@ -323,7 +321,7 @@
pure src

mkMono :: Monotone (VVar Z) -> Checking Src
mkMono (Linear (VPar (ExEnd e))) = pure (NamedPort e "mono")

Check warning on line 324 in brat/Brat/Checker/Helpers.hs

View workflow job for this annotation

GitHub Actions / build

Pattern match(es) are non-exhaustive

Check warning on line 324 in brat/Brat/Checker/Helpers.hs

View workflow job for this annotation

GitHub Actions / build

Pattern match(es) are non-exhaustive
mkMono (Full sm) = do
(_, [], [(twoSrc,_)], _) <- next "2" (Const (Num 2)) (S0, Some (Zy :* S0)) R0 (RPr ("value", TNat) R0)
(_, [(lhs,_),(rhs,_)], [(powSrc,_)], _) <- next "2^" (ArithNode Pow) (S0, Some (Zy :* S0))
Expand All @@ -341,27 +339,29 @@
wire (oneSrc, TNat, rhs)
pure src

vectorise :: (Src, Val Z) -> Checking (Src, Val Z)
vectorise (src, ty) = do
(layers, Some (my :* Flip cty)) <- vecLayers ty
modily my $ mkMapFuns (src, VFun my cty) layers
vectorise :: forall m. Modey m -> (Src, Val Z) -> Checking (Src, CTy m Z)
vectorise my (src, ty) = do
(layers, cty) <- vecLayers my ty
modily my $ foldrM mkMapFun (src, cty) layers
where
mkMapFuns :: (Src, Val Z) -- The input to the mapfun
-> [(Src, NumVal (VVar Z))] -- Remaining layers
-> Checking (Src, Val Z)
mkMapFuns over [] = pure over
mkMapFuns (valSrc, ty) ((lenSrc, len):layers) = do
(valSrc, ty@(VFun my cty)) <- mkMapFuns (valSrc, ty) layers
mkMapFun :: (Src, NumVal (VVar Z)) -- Layer to apply
-> (Src, CTy m Z) -- The input to this level of mapfun
-> Checking (Src, CTy m Z)
mkMapFun (lenSrc, len) (valSrc, cty) = do
let weak1 = changeVar (Thinning (ThDrop ThNull))

Check warning on line 351 in brat/Brat/Checker/Helpers.hs

View workflow job for this annotation

GitHub Actions / build

• The Monomorphism Restriction applies to the binding for ‘weak1’

Check warning on line 351 in brat/Brat/Checker/Helpers.hs

View workflow job for this annotation

GitHub Actions / build

• The Monomorphism Restriction applies to the binding for ‘weak1’
vecFun <- vectorisedFun len my cty
(_, [(lenTgt,_), (valTgt, _)], [(vectorSrc, Right vecTy)], _) <-
(_, [(lenTgt,_), (valTgt, _)], [(vectorSrc, Right (VFun my' cty))], _) <-
next "" MapFun (S0, Some (Zy :* S0))
(REx ("len", Nat) (RPr ("value", weak1 ty) R0))
(RPr ("vector", weak1 vecFun) R0)
defineTgt lenTgt (VNum len)
wire (lenSrc, kindType Nat, lenTgt)
wire (valSrc, ty, valTgt)
pure (vectorSrc, vecTy)
let vecCTy = case (my,my',cty) of
(Braty,Braty,cty) -> cty
(Kerny,Kerny,cty) -> cty
_ -> error "next returned wrong mode of computation type to that passed in"
pure (vectorSrc, vecCTy)

vectorisedFun :: NumVal (VVar Z) -> Modey m -> CTy m Z -> Checking (Val Z)
vectorisedFun nv my (ss :->> ts) = do
Expand Down
4 changes: 1 addition & 3 deletions brat/test/golden/kernel/kernel_application.brat.golden
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,5 @@ Error in test/golden/kernel/kernel_application.brat on line 16:
rotate = { q => maybeRotate(true) }
^^^^^^^^^^^

Expected function to be a (kernel) thunk, but found:
(thunk :: { (a1 :: Bool) -> (a1 :: { (a1 :: Qubit) -o (a1 :: Qubit) }) })

Type error: Expected a (kernel) function or vector of functions, got { (a1 :: Bool) -> (a1 :: { (a1 :: Qubit) -o (a1 :: Qubit) }) }

Loading