Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into refactor/parser-wc
Browse files Browse the repository at this point in the history
  • Loading branch information
croyzor committed Dec 19, 2024
2 parents 859a144 + f0c22a4 commit 937ef01
Show file tree
Hide file tree
Showing 5 changed files with 129 additions and 108 deletions.
104 changes: 53 additions & 51 deletions brat/Brat/Compile/Hugr.hs
Original file line number Diff line number Diff line change
Expand Up @@ -137,11 +137,10 @@ runCheckingInCompile (Req _ _) = error "Compile monad found a command it can't h

-- To be called on top-level signatures which are already Inx-closed, but not
-- necessarily normalised.
compileSig :: Modey m -> CTy m Z -> Compile PolyFuncType
compileSig my cty = do
runCheckingInCompile (evalCTy S0 my cty) <&> compileCTy
compileSig :: Modey m -> CTy m Z -> Compile ([HugrType], [HugrType])
compileSig my cty = runCheckingInCompile (evalCTy S0 my cty) <&> (\(ss :->> ts) -> (compileRo ss, compileRo ts))

compileCTy (ss :->> ts )= PolyFuncType [] (FunctionType (compileRo ss) (compileRo ts))
compileCTy (ss :->> ts) = PolyFuncType [] (FunctionType (compileRo ss) (compileRo ts) bratExts)

compileRo :: Ro m i j -- The Ro that we're processing
-> [HugrType] -- The hugr type of the row
Expand Down Expand Up @@ -201,7 +200,7 @@ compileConst parent tm ty = do
constId <- addNode "Const" (OpConst (ConstOp parent (valFromSimple tm)))
loadId <- case ty of
HTFunc poly@(PolyFuncType [] _) ->
addNode "LoadFunction" (OpLoadFunction (LoadFunctionOp parent poly [] (FunctionType [] [HTFunc poly])))
addNode "LoadFunction" (OpLoadFunction (LoadFunctionOp parent poly [] (FunctionType [] [HTFunc poly] [])))
HTFunc (PolyFuncType _ _) -> error "Trying to compile function with type args"
_ -> addNode "LoadConst" (OpLoadConstant (LoadConstantOp parent ty))
addEdge (Port constId 0, Port loadId 0)
Expand All @@ -212,13 +211,13 @@ compileArithNode parent op TNat = addNode (show op ++ "_Nat") $ OpCustom $ case
Add -> binaryIntOp parent "iadd"
Sub -> binaryIntOp parent "isub"
Mul-> binaryIntOp parent "imul"
Div -> intOp parent "idiv_u" (FunctionType [hugrInt, hugrInt] [hugrInt]) [TANat intWidth, TANat intWidth]
Div -> intOp parent "idiv_u" [hugrInt, hugrInt] [hugrInt] [TANat intWidth, TANat intWidth]
Pow -> error "TODO: Pow" -- Not defined in extension
compileArithNode parent op TInt = addNode (show op ++ "_Int") $ OpCustom $ case op of
Add -> binaryIntOp parent "iadd"
Sub -> binaryIntOp parent "isub"
Mul-> binaryIntOp parent "imul"
Div -> intOp parent "idiv_u" (FunctionType [hugrInt, hugrInt] [hugrInt]) [TANat intWidth, TANat intWidth]
Div -> intOp parent "idiv_u" [hugrInt, hugrInt] [hugrInt] [TANat intWidth, TANat intWidth]
Pow -> error "TODO: Pow" -- Not defined in extension
compileArithNode parent op TFloat = addNode (show op ++ "_Float") $ OpCustom $ case op of
Add -> binaryFloatOp parent "fadd"
Expand Down Expand Up @@ -275,7 +274,7 @@ compileClauses parent ins ((matchData, rhs) :| clauses) = do
(ns, _) <- gets bratGraph
-- RHS has to be a box, so it must have a function type
outTys <- case nodeOuts (ns M.! rhs) of
[(_, VFun my cty)] -> compileSig my cty >>= (\(FunctionType _ outs) -> pure outs) . body
[(_, VFun my cty)] -> compileSig my cty >>= (\(_, outs) -> pure outs)
_ -> error "Expected 1 kernel function type from rhs"

-- Compile the match: testResult is the port holding the dynamic match result
Expand All @@ -295,13 +294,13 @@ compileClauses parent ins ((matchData, rhs) :| clauses) = do
didntMatch outTys parent ins = case nonEmpty clauses of
Just clauses -> compileClauses parent ins clauses
-- If there are no more clauses left to test, then the Hugr panics
Nothing -> let sig = FunctionType (snd <$> ins) outTys in
addNodeWithInputs "Panic" (OpCustom (CustomOp parent "brat" "panic" sig [])) ins outTys
Nothing -> let sig = FunctionType (snd <$> ins) outTys ["BRAT"] in
addNodeWithInputs "Panic" (OpCustom (CustomOp parent "BRAT" "panic" sig [])) ins outTys

didMatch :: [HugrType] -> NodeId -> [TypedPort] -> Compile [TypedPort]
didMatch outTys parent ins = gets bratGraph >>= \(ns,_) -> case ns M.! rhs of
BratNode (Box _venv src tgt) _ _ -> do
dfgId <- addNode "DidMatch_DFG" (OpDFG (DFG parent (FunctionType (snd <$> ins) outTys)))
dfgId <- addNode "DidMatch_DFG" (OpDFG (DFG parent (FunctionType (snd <$> ins) outTys bratExts)))
compileBox (src, tgt) dfgId
for_ (zip (fst <$> ins) (Port dfgId <$> [0..])) addEdge
pure $ zip (Port dfgId <$> [0..]) outTys
Expand Down Expand Up @@ -337,13 +336,13 @@ compileWithInputs parent name = gets compiled >>= (\case
let (funcDef, extra_call) = decls M.! name
nod <- if extra_call
then addNode ("direct_call(" ++ show funcDef ++ ")")
(OpCall (CallOp parent (FunctionType [] hTys)))
(OpCall (CallOp parent (FunctionType [] hTys bratExts)))
-- We are loading idNode as a value (not an Eval'd thing), and it is a FuncDef directly
-- corresponding to a Brat TLD (not that produces said TLD when eval'd)
else case hTys of
[HTFunc poly@(PolyFuncType [] _)] ->
addNode ("load_thunk(" ++ show funcDef ++ ")")
(OpLoadFunction (LoadFunctionOp parent poly [] (FunctionType [] [HTFunc poly])))
(OpLoadFunction (LoadFunctionOp parent poly [] (FunctionType [] [HTFunc poly] [])))
[HTFunc (PolyFuncType args _)] -> error $ unwords ["Unexpected type args to"
,show funcDef ++ ":"
,show args
Expand Down Expand Up @@ -376,7 +375,7 @@ compileWithInputs parent name = gets compiled >>= (\case
Splice (Ex outNode _) -> default_edges <$> do
ins <- compilePorts ins
outs <- compilePorts outs
let sig = FunctionType ins outs
let sig = FunctionType ins outs bratExts
case hasPrefix ["checking", "globals", "prim"] outNode of
-- If we're evaling a Prim, we add it directly into the kernel graph
Just suffix -> do
Expand All @@ -401,11 +400,12 @@ compileWithInputs parent name = gets compiled >>= (\case
let n = ext ++ ('_':op)
let [] = ins
let [(_, VFun Braty cty)] = outs
box_sig@(FunctionType inputTys outputTys) <- body <$> compileSig Braty cty
((Port loadConst _, _ty), ()) <- compileConstDfg parent n box_sig $ \dfg_id -> do
ins <- addNodeWithInputs ("Inputs" ++ n) (OpIn (InputNode dfg_id inputTys)) [] inputTys
outs <- addNodeWithInputs n (OpCustom (CustomOp dfg_id ext op box_sig [])) ins outputTys
addNodeWithInputs ("Outputs" ++ n) (OpOut (OutputNode dfg_id outputTys)) outs []
boxSig@(inputTys, outputTys) <- compileSig Braty cty
let boxFunTy = FunctionType inputTys outputTys bratExts
((Port loadConst _, _ty), ()) <- compileConstDfg parent n boxSig $ \dfgId -> do
ins <- addNodeWithInputs ("Inputs" ++ n) (OpIn (InputNode dfgId inputTys)) [] inputTys
outs <- addNodeWithInputs n (OpCustom (CustomOp dfgId ext op boxFunTy [])) ins outputTys
addNodeWithInputs ("Outputs" ++ n) (OpOut (OutputNode dfgId outputTys)) outs []
pure ()
pure $ default_edges loadConst

Expand All @@ -419,13 +419,13 @@ compileWithInputs parent name = gets compiled >>= (\case
-- Callee is a Prim node, insert Hugr Op; first look up outNode in the BRAT graph to get the Prim data
Just suffix -> default_edges <$> case M.lookup outNode ns of
Just (BratNode (Prim (ext,op)) _ _) -> do
addNode (show suffix) (OpCustom (CustomOp parent ext op (FunctionType ins outs) []))
addNode (show suffix) (OpCustom (CustomOp parent ext op (FunctionType ins outs [ext]) []))
x -> error $ "Expected a Prim node but got " ++ show x
Nothing -> case hasPrefix ["checking", "globals"] outNode of
-- Callee is a user-defined global def that, since it does not require an "extra" call, can be turned from IndirectCall to direct.
Just _ | (funcDef, False) <- fromJust (M.lookup outNode decls) -> do
callerId <- addNode ("direct_call(" ++ show funcDef ++ ")")
(OpCall (CallOp parent (FunctionType ins outs)))
(OpCall (CallOp parent (FunctionType ins outs bratExts)))
-- Add the static edge from the FuncDefn node to the port *after*
-- all of the dynamic arguments to the Call node.
-- This is because in hugr, static edges (like the graph arg to a
Expand All @@ -437,7 +437,7 @@ compileWithInputs parent name = gets compiled >>= (\case
_ -> compileWithInputs parent outNode >>= \case
Just calleeId -> do
callerId <- addNode ("indirect_call(" ++ show calleeId ++ ")")
(OpCallIndirect (CallIndirectOp parent (FunctionType ins outs)))
(OpCallIndirect (CallIndirectOp parent (FunctionType ins outs bratExts {-[]-})))
-- for an IndirectCall, the callee (thunk, function value) is the *first*
-- Hugr input. So move all the others along, and add that extra edge.
pure $ Just (callerId, 1, [(Port calleeId outPort, 0)])
Expand Down Expand Up @@ -468,11 +468,11 @@ compileWithInputs parent name = gets compiled >>= (\case
case outs of
[(_, VCon tycon _)] -> do
outs <- compilePorts outs
compileConstructor parent tycon c (FunctionType ins outs)
compileConstructor parent tycon c (FunctionType ins outs ["BRAT"])
PatternMatch cs -> default_edges <$> do
ins <- compilePorts ins
outs <- compilePorts outs
dfgId <- addNode "DidMatch_DFG" (OpDFG (DFG parent (FunctionType ins outs)))
dfgId <- addNode "DidMatch_DFG" (OpDFG (DFG parent (FunctionType ins outs bratExts)))
inputNode <- addNode "PatternMatch.Input" (OpIn (InputNode dfgId ins))
ccOuts <- compileClauses dfgId (zip (Port inputNode <$> [0..]) ins) cs
addNodeWithInputs "PatternMatch.Output" (OpOut (OutputNode dfgId (snd <$> ccOuts))) ccOuts []
Expand All @@ -483,7 +483,7 @@ compileWithInputs parent name = gets compiled >>= (\case
ins <- compilePorts ins
let [_, elemTy] = ins
outs <- compilePorts outs
let sig = FunctionType ins outs
let sig = FunctionType ins outs bratExts
addNode "Replicate" (OpCustom (CustomOp parent "BRAT" "Replicate" sig [TAType elemTy]))
x -> error $ show x ++ " should have been compiled outside of compileNode"

Expand Down Expand Up @@ -516,26 +516,28 @@ getOutPort parent p@(Ex srcNode srcPort) = do
-- Execute a compilation (which takes a DFG parent) in a nested monad;
-- produce a Const node containing the resulting Hugr, and a LoadConstant,
-- and return the latter.
compileConstDfg :: NodeId -> String -> FunctionType -> (NodeId -> Compile a) -> Compile (TypedPort, a)
compileConstDfg parent desc box_sig contents = do
compileConstDfg :: NodeId -> String -> ([HugrType], [HugrType]) -> (NodeId -> Compile a) -> Compile (TypedPort, a)
compileConstDfg parent desc (inTys, outTys) contents = do
st <- gets store
g <- gets bratGraph
-- First, we fork off a new namespace
(res, cs) <- desc -! do
((funTy, a), cs) <- desc -! do
ns <- gets nameSupply
pure $ flip runState (emptyCS g ns st) $ do
-- make a DFG node at the root. We can't use `addNode` since the
-- DFG needs itself as parent
dfg_id <- freshNode ("Box_" ++ show desc)
addOp (OpDFG $ DFG dfg_id box_sig) dfg_id
contents dfg_id
a <- contents dfg_id
let funTy = FunctionType inTys outTys bratExts
addOp (OpDFG $ DFG dfg_id funTy) dfg_id
pure (funTy, a)
let nestedHugr = renameAndSortHugr (nodes cs) (edges cs)
let ht = HTFunc $ PolyFuncType [] box_sig
let ht = HTFunc $ PolyFuncType [] funTy

constNode <- addNode ("ConstTemplate_" ++ desc) (OpConst (ConstOp parent (HVFunction nestedHugr)))
lcPort <- head <$> addNodeWithInputs ("LoadTemplate_" ++ desc) (OpLoadConstant (LoadConstantOp parent ht))
[(Port constNode 0, ht)] [ht]
pure (lcPort, res)
pure (lcPort, a)

-- Brat computations may capture some local variables. Thus, we need
-- to lambda-lift, producing (as results) a Partial node and a list of
Expand All @@ -549,12 +551,12 @@ compileBratBox parent name (venv, src, tgt) cty = do
parmTys <- compileGraphTypes (map (binderToValue Braty . snd) params)

-- Create a FuncDefn for the lambda that takes the params as first inputs
(FunctionType inputTys outputTys) <- body <$> compileSig Braty cty
(inputTys, outputTys) <- compileSig Braty cty
let allInputTys = parmTys ++ inputTys
let box_sig = FunctionType allInputTys outputTys
let boxInnerSig = FunctionType allInputTys outputTys bratExts

(templatePort, _) <- compileConstDfg parent ("BB" ++ show name) box_sig $ \dfg_id -> do
src_id <- addNode ("LiftedCapturesInputs" ++ show name) (OpIn (InputNode dfg_id allInputTys))
(templatePort, _) <- compileConstDfg parent ("BB" ++ show name) (allInputTys, outputTys) $ \dfgId -> do
src_id <- addNode ("LiftedCapturesInputs" ++ show name) (OpIn (InputNode dfgId allInputTys))
-- Now map ports in the BRAT Graph to their Hugr equivalents.
-- Each captured value is read from an element of src_id, starting from 0
let lifted = [(src, Port src_id i) | ((src, _ty), i) <- zip params [0..]]
Expand All @@ -563,10 +565,10 @@ compileBratBox parent name (venv, src, tgt) cty = do
st <- get
put $ st {liftedOutPorts = M.fromList lifted}
-- no need to return any holes
compileWithInputs dfg_id tgt
compileWithInputs dfgId tgt

-- Finally, we add a `Partial` node to supply the captured params.
partialNode <- addNode "Partial" (OpCustom $ partialOp parent box_sig (length params))
partialNode <- addNode "Partial" (OpCustom $ partialOp parent boxInnerSig (length params))
addEdge (fst templatePort, Port partialNode 0)
edge_srcs <- for (map fst params) $ getOutPort parent
pure (partialNode, zip (map fromJust edge_srcs) [1..])
Expand All @@ -577,9 +579,9 @@ compileKernBox parent name contents cty = do
-- compile kernel nodes only into a Hugr with "Holes"
-- when we see a Splice, we'll record the func-port onto a list
-- return a Hugr with holes
box_sig <- body <$> compileSig Kerny cty
let box_ty = HTFunc $ PolyFuncType [] box_sig
(templatePort, holelist) <- compileConstDfg parent ("KB" ++ show name) box_sig $ \dfg_id -> do
boxInnerSig@(inTys, outTys) <- compileSig Kerny cty
let boxTy = HTFunc $ PolyFuncType [] (FunctionType inTys outTys bratExts)
(templatePort, holelist) <- compileConstDfg parent ("KB" ++ show name) boxInnerSig $ \dfg_id -> do
contents dfg_id
gets holes

Expand All @@ -591,11 +593,11 @@ compileKernBox parent name contents cty = do
ins <- compilePorts ins
outs <- compilePorts outs
kernel_src <- compileWithInputs parent kernel_src <&> fromJust
pure (Port kernel_src port, HTFunc (PolyFuncType [] (FunctionType ins outs))))
pure (Port kernel_src port, HTFunc (PolyFuncType [] (FunctionType ins outs bratExts))))

-- Add a substitute node to fill the holes in the template
let hole_sigs = [ body poly | (_, HTFunc poly) <- hole_ports ]
head <$> addNodeWithInputs ("subst_" ++ show name) (OpCustom (substOp parent box_sig hole_sigs)) (templatePort : hole_ports) [box_ty]
head <$> addNodeWithInputs ("subst_" ++ show name) (OpCustom (substOp parent (FunctionType inTys outTys bratExts) hole_sigs)) (templatePort : hole_ports) [boxTy]


-- We get a bunch of TypedPorts which are associated with Srcs in the BRAT graph.
Expand Down Expand Up @@ -739,7 +741,7 @@ makeConditional parent discrim otherInputs cases = do
outId <- addNode ("Output" ++ name) (OpOut (OutputNode caseId outTys))
for_ (zip (fst <$> outs) (Port outId <$> [0..])) addEdge

addOp (OpCase (ix, Case parent (FunctionType tys outTys))) caseId
addOp (OpCase (ix, Case parent (FunctionType tys outTys bratExts))) caseId
pure outTys

allRowsEqual :: [[HugrType]] -> Bool
Expand All @@ -753,7 +755,7 @@ compilePrimTest :: NodeId
-> Compile TypedPort
compilePrimTest parent (port, ty) (PrimCtorTest c tycon unpackingNode outputs) = do
let sumOut = HTSum (SG (GeneralSum [[ty], snd <$> outputs]))
let sig = FunctionType [ty] [sumOut]
let sig = FunctionType [ty] [sumOut] ["BRAT"]
testId <- addNode ("PrimCtorTest " ++ show c)
(OpCustom (CustomOp
parent
Expand All @@ -771,7 +773,7 @@ compilePrimTest parent port@(_, ty) (PrimLitTest tm) = do
[(Port constId 0, ty)] [ty]
-- Connect to a test node
let sumOut = HTSum (SG (GeneralSum [[ty], []]))
let sig = FunctionType [ty, ty] [sumOut]
let sig = FunctionType [ty, ty] [sumOut] ["BRAT"]
head <$> addNodeWithInputs ("PrimLitTest " ++ show tm)
(OpCustom (CustomOp parent "BRAT" ("PrimLitTest::" ++ show ty) sig []))
[port, loadPort]
Expand All @@ -786,7 +788,7 @@ undoPrimTest :: NodeId
-> PrimTest HugrType -- The test to undo
-> Compile TypedPort
undoPrimTest parent inPorts outTy (PrimCtorTest c tycon _ _) = do
let sig = FunctionType (snd <$> inPorts) [outTy]
let sig = FunctionType (snd <$> inPorts) [outTy] ["BRAT"]
head <$> addNodeWithInputs
("UndoCtorTest " ++ show c)
(constructorOp parent tycon c sig)
Expand Down Expand Up @@ -833,16 +835,16 @@ compileModule venv = do
[(Ex input 0, _)] | Just (BratNode (Box _ src tgt) _ outs) <- M.lookup input ns ->
case outs of
[(_, VFun Braty cty)] -> do
sig <- compileSig Braty cty
pure (sig, False, compileBox (src, tgt))
(inTys, outTys) <- compileSig Braty cty
pure (PolyFuncType [] (FunctionType inTys outTys bratExts), False, compileBox (src, tgt))
[(_, VFun Kerny cty)] -> do
-- We're compiling, e.g.
-- f :: { Qubit -o Qubit }
-- f = { h; circ(pi) }
-- Although this looks like a constant kernel, we'll have to compile the
-- computation that produces this constant. We do so by making a FuncDefn
-- that takes no arguments and produces the constant kernel graph value.
thunkTy <- HTFunc <$> compileSig Kerny cty
thunkTy <- HTFunc . PolyFuncType [] . (\(ins, outs) -> FunctionType ins outs bratExts) <$> compileSig Kerny cty
pure (funcReturning [thunkTy], True, \parent ->
withIO parent thunkTy $ compileKernBox parent input (compileBox (src, tgt)) cty)
_ -> error "Box should have exactly one output of Thunk type"
Expand Down Expand Up @@ -873,7 +875,7 @@ compileModule venv = do
pure id

funcReturning :: [HugrType] -> PolyFuncType
funcReturning outs = PolyFuncType [] (FunctionType [] outs)
funcReturning outs = PolyFuncType [] (FunctionType [] outs bratExts)

compileNoun :: [HugrType] -> [OutPort] -> NodeId -> Compile ()
compileNoun outs srcPorts parent = do
Expand Down
Loading

0 comments on commit 937ef01

Please sign in to comment.