Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow vector patterns to be used in kernels #38

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions brat/Brat/Checker.hs
Original file line number Diff line number Diff line change
Expand Up @@ -866,9 +866,9 @@ detectVecErrors :: UserName -- Term constructor name
-> Checking (Error -> Error) -- Returns error wrapper to use for recursion
detectVecErrors vcon (PrefixName [] "Vec") [_, VNum n] [_, VPNum p] ty tp =
case numMatch B0 n p of
Left (NumMatchFail _ _) -> do
p' <- toLenConstr p
err $ getVecErr tp (show ty) (show n) p'
Left (NumMatchFail _ _) -> case (toLenConstr p) of
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So whilst continuing the slightly-hacked-in nature here and a better framework might be nice I don't think this is all that bad, certainly I'm not objecting to patching it in like this...

Right p' -> err $ getVecErr tp (show ty) (show n) p'
Left p' -> err $ InternalError ("detectVecErrors: Unexpected pattern: " ++ show (toLenConstr p'))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's this toLenConstr p'? This seems odd. So after toLenConstr has stripped away any outer (NP1Plus?) NPTwoTimess from the input p to give p', we then call it again to construct an error message? Is that right?

I'm hoping this should/could just be show p' in which case I have a different suggestion - how about moving the error-construction into toLenConstr (so it's :: NumPat -> Either Error LengthConstraint) and then using throwLeft here?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed, it's a bit hard to follow as is

-- Even if we succed here, the error might pop up when checking the
-- rest of the vector. We return a function here that intercepts the
-- error and extends it to the whole vector.
Expand All @@ -877,13 +877,14 @@ detectVecErrors vcon (PrefixName [] "Vec") [_, VNum n] [_, VPNum p] ty tp =
pure (consError fc tp (show ty) n)
else pure id
where
-- For constructors that produce something of type Vec we should
-- only ever get the patterns NP0 (if vcon == PrefixName [] "nil")
-- and NP1Plus (if vcon == PrefixName [] "cons")
toLenConstr :: NumPat -> Checking LengthConstraint
-- Try to work out the length of a vector
-- We only want to know if the vector length is nil or a successor
toLenConstr :: NumPat -> Either NumPat (LengthConstraint)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Superfluous extra brackets here, at least

toLenConstr NP0 = pure $ Length 0
toLenConstr (NP1Plus (NP2Times np)) = either (const (Right LengthOdd)) Right (toLenConstr np)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok...so....what does this do.....

  • If toLenConstr np returns (Left error), we discard the error and return pure LengthOdd
  • If toLenConstr np returns a Right value, we leave it unchanged

Is that right? (This is quite hard to follow)

If so, how about pure $ fromRight LengthOdd (toLenConstr np) ?

toLenConstr (NP1Plus _) = pure $ LongerThan 0
toLenConstr p = err $ InternalError ("detectVecErrors: Unexpected pattern: " ++ show p)
toLenConstr (NP2Times np) = either (const (Right LengthEven)) Right (toLenConstr np)
toLenConstr p = Left p
detectVecErrors _ _ _ _ _ _ = pure id

getVecErr :: Either (Term d k) Pattern -> (String -> String -> LengthConstraint -> ErrorMsg)
Expand Down
22 changes: 22 additions & 0 deletions brat/Brat/Constructors.hs
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,28 @@ kernelConstructors = M.fromList
(RPr ("tail", TVec (VApp (VInx (VS VZ)) B0) (VNum $ nVar (VInx VZ)))
(RPr ("head", VApp (VInx (VS VZ)) B0) R0)))
])
,(CConcatEqEven, M.fromList
[(CVec, CArgs [VPVar, VPNum (NP2Times NPVar)] (Sy (Sy Zy))
-- Star should be a TypeFor m forall m?
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not really sure I understand this comment...what happens if you replace the star with that TypeFor?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll try and push something to clear this up

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My idea was that you could split on classical variables in the kernel, like

divConquer(n :: #, Vec({ Qubit -o Qubit }, n)) -> { Vec(Qubit, n) -o Vec(Qubit, n) }
divConquer(0, []) = { .. }
divConquer(doub(n), gs) = { qs0 =,= qs1 =>
  let gs0 =,= gs1 = gs in divConquer(n, xs0)(gs0) =,= divConquer(n, gs1)(qs1)
}

but having just tried that, there are more blockers (gs isn't in scope in a kernel context anyway!) so I'll remove the comments.
Also, it'd be nice to be able to do this, as long as the pattern match is irrefutable... I guess we should come back to it after we have checking coverage of patterns

(REx ("elementType", Star []) ((REx ("halfLength", Nat) (R0))))
(RPr ("lhs", TVec (VApp (VInx (VS VZ)) B0) (VApp (VInx VZ) B0))
(RPr ("rhs", TVec (VApp (VInx (VS VZ)) B0) (VApp (VInx VZ) B0)) R0)))
])
,(CRiffle, M.fromList
[(CVec, CArgs [VPVar, VPNum (NP2Times NPVar)] (Sy (Sy Zy))
-- Star should be a TypeFor m forall m?
(REx ("elementType", Star []) ((REx ("halfLength", Nat) (R0))))
(RPr ("evens", TVec (VApp (VInx (VS VZ)) B0) (VApp (VInx VZ) B0))
(RPr ("odds", TVec (VApp (VInx (VS VZ)) B0) (VApp (VInx VZ) B0)) R0)))
])
,(CConcatEqOdd, M.fromList
[(CVec, CArgs [VPVar, VPNum (NP1Plus (NP2Times NPVar))] (Sy (Sy Zy))
-- Star should be a TypeFor m forall m?
(REx ("elementType", Star []) ((REx ("halfLength", Nat) (R0))))
(RPr ("lhs", TVec (VApp (VInx (VS VZ)) B0) (VApp (VInx VZ) B0))
(RPr ("mid", VApp (VInx (VS VZ)) B0)
(RPr ("rhs", TVec (VApp (VInx (VS VZ)) B0) (VApp (VInx VZ) B0)) R0))))
])
,(CTrue, M.fromList [(CBit, CArgs [] Zy R0 R0)])
,(CFalse, M.fromList [(CBit, CArgs [] Zy R0 R0)])
]
Expand Down
12 changes: 7 additions & 5 deletions brat/Brat/Error.hs
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@
instance Show ParseError where
show = pretty

data LengthConstraintF a = Length a | LongerThan a deriving (Eq, Functor)
data LengthConstraintF a = Length a | LongerThan a | LengthEven | LengthOdd deriving (Eq, Functor)
instance Show a => Show (LengthConstraintF a) where
show (Length a) = show a
show (LongerThan a) = "(> " ++ show a ++ ")"
show (Length a) = "of length " ++ show a
show (LongerThan a) = "with length (> " ++ show a ++ ")"
show LengthEven = "of even length"
show LengthOdd = "of odd length"

type LengthConstraint = LengthConstraintF Int

Expand Down Expand Up @@ -106,12 +108,12 @@
show (VecLength tm ty exp act) = unlines ["Expected vector of length " ++ exp
,"from the type: " ++ ty
,"but got vector: " ++ tm
,"of length " ++ show act
,show act
]
show (VecPatLength abs ty exp act) = unlines ["Pattern: " ++ abs
,"doesn't match type " ++ ty
,"(expected vector pattern of length " ++ exp ++
" but got vector pattern of length " ++ show act ++ ")"
" but got vector pattern " ++ show act ++ ")"
]
show (NotVecPat tm ty)= unwords ["Expected", tm
,"to be a vector pattern when binding type", ty]
Expand Down Expand Up @@ -167,7 +169,7 @@
show UnreachableBranch = "Branch cannot be reached"
show (CompilingHoles hs) = unlines ("Can't compile file with remaining holes": indent hs)
where
indent = fmap (" " ++)

Check warning on line 172 in brat/Brat/Error.hs

View workflow job for this annotation

GitHub Actions / build

• The Monomorphism Restriction applies to the binding for ‘indent’

Check warning on line 172 in brat/Brat/Error.hs

View workflow job for this annotation

GitHub Actions / build

• The Monomorphism Restriction applies to the binding for ‘indent’
show (ThunkLeftOvers overs) = "Expected function to address all inputs, but " ++ overs ++ " wasn't used"
show (ThunkLeftUnders unders) = "Expected function to return additional values of type: " ++ unders

Expand Down Expand Up @@ -209,8 +211,8 @@
ls = lines contents
in case endLineN - startLineN of
0 -> [ls!!startLineN, highlightSection startCol endCol]
n | n > 0 -> let (first:rest) = drop (startLineN - 1) $ take (endLineN + 1) ls

Check warning on line 214 in brat/Brat/Error.hs

View workflow job for this annotation

GitHub Actions / build

Pattern match(es) are non-exhaustive

Check warning on line 214 in brat/Brat/Error.hs

View workflow job for this annotation

GitHub Actions / build

Pattern match(es) are non-exhaustive
(last:rmid) = reverse rest

Check warning on line 215 in brat/Brat/Error.hs

View workflow job for this annotation

GitHub Actions / build

Pattern match(es) are non-exhaustive

Check warning on line 215 in brat/Brat/Error.hs

View workflow job for this annotation

GitHub Actions / build

Pattern match(es) are non-exhaustive
in [first, highlightSection startCol (length first)]
++ (reverse rmid >>= (\l -> [l, highlightSection 0 (length l)]))
++ [last, highlightSection 0 endCol]
Expand Down
20 changes: 20 additions & 0 deletions brat/examples/brick.brat
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
-- Create a "brickwork" state
-- Apply a parameterised unitary U to entangle every adjacent pair of qubits in a line architecture.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Eeek - skipping dependent types to ensure weights match in length, is this roughly just
brick(th ,- ths, U) = {q0 ,- q1 ,- qrest => let (q0,q1) = U(th)(q0,q1) in q0 ,- brick(ths, U)(q1 ,- rest)? (I.e. gate the first two qubits, then proceed "down the line" - what I thought was called a "ladder" architecture)?

If so then whilst a fun bit of list comprehensions they don't really do much useful (you're just riffling things apart and back together, and so on), right?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes you're right, it's just a bit of fun that tests vector patterns are working in kernels 🙂

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Terminology-wise, I think this is called a brickwork pattern and a ladder is when everything else is sequentially entangled with 1 qubit

-- (don't apply the gate to (q0, qn-1))
brick(n :: #, -- The number of entangling gates
Vec(Float, n), -- Angles for each of the gates
U :: { Float -> { Qubit, Qubit -o Qubit, Qubit }}) -- Parameterised unitary
-> { Vec(Qubit, n + 1) -o Vec(Qubit, n + 1) }
brick(0, [], _) = { .. }
-- Odd number of gates, even number of qubits
brick(succ(doub(n)), th ,- ths, U) = {
(q0 ,- qsEven) =%= (q1 ,- qsOdd) =>
let q0, q1 = U(th)(q0, q1) in
q0 ,- brick(doub(n), ths, U)(q1 ,- (qsEven =%= qsOdd))
}
-- Even number of gates, odd number of qubits
brick(doub(succ(n)), th ,- ths, U) = { (q0 ,- qsl) =, qmid ,= qsr =>
let q1 ,- qs = brick(succ(doub(n)), ths, U)((qsl -, qmid) =,= qsr) in
let q0, q1 = U(th)(q0, q1) in
q0 ,- q1 ,- qs
}
3 changes: 2 additions & 1 deletion brat/test/Test/Compile/Hugr.hs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ nonCompilingExamples = (expectedCheckingFails ++ expectedParsingFails ++
-- Conjecture: These examples don't compile because number patterns in type
-- signatures causes `kindCheck` to call `abstract`, creating "Selector"
-- nodes, which we don't attempt to compile because we want to get rid of them
,"vec-pats"
,"brick" -- Creates Selectors

-- Victims of #389
,"arith"
,"bell"
Expand Down
2 changes: 1 addition & 1 deletion brat/test/golden/error/badvec3.brat.golden
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@ v3 = cons(1, nil)
Expected vector of length 0
from the type: Vec(Int, 0)
but got vector: [1]
of length (> 0)
with length (> 0)


5 changes: 5 additions & 0 deletions brat/test/golden/error/even-length.brat
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
f(X :: *, n :: #, Vec(X, n)) -> Vec(X, n)
f(X, n, xs) = xs

g :: Vec(Nat, 3)
g = f(Nat, 3, [1] =,= [2])
10 changes: 10 additions & 0 deletions brat/test/golden/error/even-length.brat.golden
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
Error in test/golden/error/even-length.brat@FC {start = Pos {line = 5, col = 15}, end = Pos {line = 5, col = 26}}:
g = f(Nat, 3, [1] =,= [2])
^^^^^^^^^^^

Expected vector of length 3
from the type: Vec(VApp VPar In checking_check_defs_1_g_1_ 0 B0, VPar In checking_check_defs_1_g_1_ 1)
but got vector: concatEqEven([1], [2])
of even length


2 changes: 1 addition & 1 deletion brat/test/golden/error/kbadvec3.brat.golden
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@ constNil = { b => cons(1, nil) }
Expected vector of length 0
from the type: Vec(Bit, 0)
but got vector: [1]
of length (> 0)
with length (> 0)


5 changes: 5 additions & 0 deletions brat/test/golden/error/odd-length.brat
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
f(X :: *, n :: #, Vec(X, n)) -> Vec(X, n)
f(X, n, xs) = xs

g :: Vec(Nat, 4)
g = f(Nat, 4, [1] =, 2 ,= [3])
10 changes: 10 additions & 0 deletions brat/test/golden/error/odd-length.brat.golden
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
Error in test/golden/error/odd-length.brat@FC {start = Pos {line = 5, col = 15}, end = Pos {line = 5, col = 30}}:
g = f(Nat, 4, [1] =, 2 ,= [3])
^^^^^^^^^^^^^^^

Expected vector of length 4
from the type: Vec(VApp VPar In checking_check_defs_1_g_1_ 0 B0, VPar In checking_check_defs_1_g_1_ 1)
but got vector: concatEqOdd([1], 2, [3])
of odd length


Loading