Skip to content

Commit

Permalink
chore: Update compilation to target hugr v0.6.0 (#52)
Browse files Browse the repository at this point in the history
Based off #51. Closes #5.
* Get rid of tupling in places where it's no longer required. We can
pass around rows more freely now
* Add LoadFunction op and use it in place of LoadConst when appropriate
* Replace `HugrConst` by adding `HugrValue` and `CustomConst` which
wraps values
  • Loading branch information
croyzor authored Nov 5, 2024
1 parent c1110ed commit 8a4ff4f
Show file tree
Hide file tree
Showing 9 changed files with 170 additions and 119 deletions.
55 changes: 28 additions & 27 deletions brat/Brat/Compile/Hugr.hs
Original file line number Diff line number Diff line change
Expand Up @@ -158,20 +158,20 @@ compileType TBool = HTSum (SU (UnitSum 2))
compileType TInt = hugrInt
compileType TNat = hugrInt
compileType TFloat = hugrFloat
compileType ty@(TCons _ _) = HTTuple (tuple ty)
compileType ty@(TCons _ _) = htTuple (tuple ty)
where
tuple :: Val n -> [HugrType]
tuple (TCons hd rest) = (compileType hd):(tuple rest)
tuple TNil = []
tuple ty = error $ "Found " ++ show ty ++ " in supposed tuple type"
compileType TNil = HTTuple []
compileType TNil = htTuple []
compileType (VSum my ros) = case my of
Braty -> error "Todo: compileTypeWorker for BRAT"
Kerny -> HTSum (SG (GeneralSum $ map (\(Some ro) -> HTTuple (compileRo ro)) ros))
Kerny -> HTSum (SG (GeneralSum $ map (\(Some ro) -> compileRo ro) ros))
compileType (TVec el _) = hugrList (compileType el)
compileType (TList el) = hugrList (compileType el)
-- All variables are of kind `TypeFor m xs`, we already checked in `kindCheckRow`
compileType (VApp _ _) = HTTuple []
compileType (VApp _ _) = htTuple []
-- VFun is already evaluated here, so we don't need to call `compileSig`
compileType (VFun _ cty) = HTFunc $ compileCTy cty
compileType ty = error $ "todo: compile type " ++ show ty
Expand All @@ -198,8 +198,12 @@ registerCompiled from to = do

compileConst :: NodeId -> SimpleTerm -> HugrType -> Compile NodeId
compileConst parent tm ty = do
constId <- addNode "Const" (OpConst (ConstOp parent (constFromSimple tm) ty))
loadId <- addNode "LoadConst" (OpLoadConstant (LoadConstantOp parent ty))
constId <- addNode "Const" (OpConst (ConstOp parent (valFromSimple tm)))
loadId <- case ty of
HTFunc poly@(PolyFuncType [] _) ->
addNode "LoadFunction" (OpLoadFunction (LoadFunctionOp parent poly [] (FunctionType [] [HTFunc poly])))
HTFunc (PolyFuncType _ _) -> error "Trying to compile function with type args"
_ -> addNode "LoadConst" (OpLoadConstant (LoadConstantOp parent ty))
addEdge (Port constId 0, Port loadId 0)
pure loadId

Expand Down Expand Up @@ -336,8 +340,12 @@ compileWithInputs parent name = gets compiled <&> M.lookup name >>= \case
(OpCall (CallOp parent (FunctionType [] hTys)))
-- We are loading idNode as a value (not an Eval'd thing), and it is a FuncDef directly
-- corresponding to a Brat TLD (not that produces said TLD when eval'd)
False -> addNode ("load_thunk(" ++ show funcDef ++ ")")
(OpLoadConstant (LoadConstantOp parent (let [ty] = hTys in ty)))
False -> case hTys of
[HTFunc poly@(PolyFuncType [] _)] ->
addNode ("load_thunk(" ++ show funcDef ++ ")")
(OpLoadFunction (LoadFunctionOp parent poly [] (FunctionType [] [HTFunc poly])))
[HTFunc (PolyFuncType args _)] -> error $ "Unexpected type args to " ++ show funcDef ++ ": " ++ show args
_ -> error $ "Expected a function argument when loading thunk, got: " ++ show hTys
-- the only input
pure $ Just (nod, [(Port funcDef 0, 0)])
compileNode in_edges = do
Expand Down Expand Up @@ -479,12 +487,10 @@ compileWithInputs parent name = gets compiled <&> M.lookup name >>= \case
compileConstructor :: NodeId -> UserName -> UserName -> FunctionType -> Compile NodeId
compileConstructor parent tycon con sig
| Just b <- isBool con = do
-- A boolean value is a tuple and a tag
-- This is the same thing that happens in Brat.Checker.Clauses (makeDiscriminator)
makeTuple <- addNode "bool.MakeTuple" (OpMakeTuple (MakeTupleOp parent []))
tag <- addNode "bool.tag" (OpTag (TagOp parent (if b then 1 else 0) [HTTuple [], HTTuple []]))
addEdge (Port makeTuple 0, Port tag 0)
pure tag
-- A boolean value is a tag which takes no inputs and produces an empty tuple
-- This is the same thing that happens in Brat.Checker.Clauses to make the
-- discriminator (makeRowTag)
addNode "bool.tag" (OpTag (TagOp parent (if b then 1 else 0) [[], []]))
| otherwise = let name = "Constructor " ++ show tycon ++ "::" ++ show con in
addNode name (constructorOp parent tycon con sig)
where
Expand Down Expand Up @@ -523,7 +529,7 @@ compileConstDfg parent desc box_sig contents = do
let nestedHugr = renameAndSortHugr (nodes cs) (edges cs)
let ht = HTFunc $ PolyFuncType [] box_sig

constNode <- addNode ("ConstTemplate_" ++ desc) (OpConst (ConstOp parent (HCFunction nestedHugr) ht))
constNode <- addNode ("ConstTemplate_" ++ desc) (OpConst (ConstOp parent (HVFunction nestedHugr)))
lcPort <- head <$> addNodeWithInputs ("LoadTemplate_" ++ desc) (OpLoadConstant (LoadConstantOp parent ht))
[(Port constNode 0, ht)] [ht]
pure (lcPort, res)
Expand Down Expand Up @@ -683,17 +689,12 @@ compileMatchSequence parent portTable (MatchSequence {..}) = do
makeRowTag "DidNotMatch" parent 0 sumTy ins

makeRowTag :: String -> NodeId -> Int -> SumOfRows -> [TypedPort] -> Compile [TypedPort]
makeRowTag hint parent tag sor@(SoR sumRows) ins = assert (sumRows !! tag == (snd <$> ins)) $ do
tuple <- addNodeWithInputs (hint ++ "_MakeTuple") (OpMakeTuple (MakeTupleOp parent (snd <$> ins))) ins [HTTuple (snd <$> ins)]
addNodeWithInputs (hint ++ "_Tag") (OpTag (TagOp parent tag (HTTuple <$> sumRows))) tuple [compileSumOfRows sor]
makeRowTag hint parent tag sor@(SoR sumRows) ins = assert (sumRows !! tag == (snd <$> ins)) $
addNodeWithInputs (hint ++ "_Tag") (OpTag (TagOp parent tag sumRows)) ins [compileSumOfRows sor]

getSumVariants :: HugrType -> [[HugrType]]
getSumVariants (HTSum (SU (UnitSum n))) = replicate n []
getSumVariants (HTSum (SG (GeneralSum rows))) = fromTuple <$> rows
where
fromTuple :: HugrType -> [HugrType]
fromTuple (HTTuple row) = row
fromTuple _ = error "Expected row of tuples in getSumVariants"
getSumVariants (HTSum (SG (GeneralSum rows))) = rows
getSumVariants ty = error $ "Expected a sum type, got " ++ show ty


Expand Down Expand Up @@ -748,7 +749,7 @@ compilePrimTest :: NodeId
-> PrimTest HugrType -- The test to run
-> Compile TypedPort
compilePrimTest parent (port, ty) (PrimCtorTest c tycon unpackingNode outputs) = do
let sumOut = (HTSum (SG (GeneralSum [HTTuple [ty], HTTuple (snd <$> outputs)])))
let sumOut = (HTSum (SG (GeneralSum [[ty], snd <$> outputs])))
let sig = FunctionType [ty] [sumOut]
testId <- addNode ("PrimCtorTest " ++ show c)
(OpCustom (CustomOp
Expand All @@ -762,11 +763,11 @@ compilePrimTest parent (port, ty) (PrimCtorTest c tycon unpackingNode outputs) =
pure (Port testId 0, sumOut)
compilePrimTest parent port@(_, ty) (PrimLitTest tm) = do
-- Make a Const node that holds the value we test against
constId <- addNode "LitConst" (OpConst (ConstOp parent (constFromSimple tm) ty))
constId <- addNode "LitConst" (OpConst (ConstOp parent (valFromSimple tm)))
loadPort <- head <$> addNodeWithInputs "LitLoad" (OpLoadConstant (LoadConstantOp parent ty))
[(Port constId 0, ty)] [ty]
-- Connect to a test node
let sumOut = HTSum (SG (GeneralSum [HTTuple [ty], HTTuple []]))
let sumOut = HTSum (SG (GeneralSum [[ty], []]))
let sig = FunctionType [ty, ty] [sumOut]
head <$> addNodeWithInputs ("PrimLitTest " ++ show tm)
(OpCustom (CustomOp parent "BRAT" ("PrimLitTest::" ++ show ty) sig []))
Expand All @@ -790,7 +791,7 @@ undoPrimTest parent inPorts outTy (PrimCtorTest c tycon _ _) = do
[outTy]
undoPrimTest parent inPorts outTy (PrimLitTest tm) = do
unless (null inPorts) $ error "Unexpected inPorts"
constId <- addNode "LitConst" (OpConst (ConstOp parent (constFromSimple tm) outTy))
constId <- addNode "LitConst" (OpConst (ConstOp parent (valFromSimple tm)))
head <$> addNodeWithInputs "LitLoad" (OpLoadConstant (LoadConstantOp parent outTy))
[(Port constId 0, outTy)] [outTy]

Expand Down
Loading

0 comments on commit 8a4ff4f

Please sign in to comment.