Skip to content

Commit

Permalink
[new] Allow pattern matching in lambdas, and compile constructors (#22)
Browse files Browse the repository at this point in the history
Allow lambda expressions to have multiple clauses employing pattern
matching logic just like what is done for top level definitions. This
improves expressiveness as previously we couldn't pattern match in
kernels that were defined as lambdas created by classical functions.

These multi-clause lambda expressions - "multilambdas" - are parsed
inside curly braces, as many lambda expressions separated by a `|`.

Closes #8, #17

---------

Co-authored-by: Alan Lawrence <[email protected]>
  • Loading branch information
croyzor and acl-cqc authored Aug 7, 2024
1 parent ffa0c53 commit ed58724
Show file tree
Hide file tree
Showing 30 changed files with 424 additions and 301 deletions.
195 changes: 165 additions & 30 deletions brat/Brat/Checker.hs
Original file line number Diff line number Diff line change
@@ -1,37 +1,29 @@
module Brat.Checker (check
module Brat.Checker (checkBody
,check
,run
,VEnv
,Checking
,Graph
,Modey(..)
,Node
,CheckingSig(..)
,TypedHole(..)
,wrapError
,next, knext
,localFC
,emptyEnv
,checkInputs, checkOutputs, checkThunk
,CheckConstraints
,TensorOutputs(..)
,kindCheck, kindCheckRow, kindCheckAnnotation
,mkArgRo
,weaken
,kindCheck
,kindCheckAnnotation
,kindCheckRow
,tensor
) where

import Control.Arrow (first)
import Control.Monad (foldM)
import Control.Monad.Freer
import Data.Bifunctor (second)
import Data.Functor (($>), (<&>))
-- import Data.List (filter, intercalate, transpose)
import Data.List ((\\))
import Data.List.NonEmpty (NonEmpty(..))
import qualified Data.List.NonEmpty as NE
import qualified Data.Map as M
import Data.Maybe (fromJust)
import Data.Type.Equality ((:~:)(..))
import Prelude hiding (filter)

import Brat.Checker.Helpers
import Brat.Checker.Monad
import Brat.Checker.Quantity
import Brat.Checker.SolvePatterns (argProblems, argProblemsWithLeftovers, solve)
import Brat.Checker.Types
import Brat.Constructors
import Brat.Error
Expand All @@ -41,14 +33,17 @@ import qualified Brat.FC as FC
import Brat.Graph
import Brat.Naming
-- import Brat.Search
import Brat.Syntax.Abstractor (NormalisedAbstractor(..), normaliseAbstractor)
import Brat.Syntax.Common
import Brat.Syntax.Core
import Brat.Syntax.Port (toEnd)
import Brat.Syntax.FuncDecl (FunBody(..))
import Brat.Syntax.Port (ToEnd, toEnd)
import Brat.Syntax.Simple
import Brat.Syntax.Value
import Brat.UserName
import Bwd
import Hasochism
import Util (zip_same_length)

-- Put things into a standard form in a kind-directed manner, such that it is
-- meaningful to do case analysis on them
Expand All @@ -60,9 +55,6 @@ standardise k val = eval S0 val <&> (k,) >>= \case
mergeEnvs :: [Env a] -> Checking (Env a)
mergeEnvs = foldM combineDisjointEnvs M.empty

emptyEnv :: Env a
emptyEnv = M.empty

singletonEnv :: (?my :: Modey m) => String -> (Src, BinderType m) -> Checking (Env (EnvData m))
singletonEnv x input@(p, ty) = case ?my of
Braty -> pure $ M.singleton (plain x) [(p, ty)]
Expand Down Expand Up @@ -187,7 +179,8 @@ checkThunk m name cty tm = do
check :: (CheckConstraints m k
,EvMode m
,TensorOutputs (Outputs m d)
,?my :: Modey m)
,?my :: Modey m
, DIRY d)
=> WC (Term d k)
-> ChkConnectors m d k
-> Checking (SynConnectors m d k
Expand All @@ -198,7 +191,8 @@ check' :: forall m d k
. (CheckConstraints m k
,EvMode m
,TensorOutputs (Outputs m d)
,?my :: Modey m)
,?my :: Modey m
, DIRY d)
=> Term d k
-> ChkConnectors m d k
-> Checking (SynConnectors m d k
Expand All @@ -218,11 +212,85 @@ check' (s :-: t) (overs, unders) = do
pure ((ins, outs), (rightovers, rightunders))
check' Pass ([], ()) = typeErr "pass is being given an empty row"
check' Pass (overs, ()) = pure (((), overs), ([], ()))
check' (Lambda (binder, body) []) (overs, unders) = do
(ext, overs) <- abstract overs (unWC binder)
(sycs, ((), unders)) <- localEnv ext $ check body ((), unders)
pure (sycs, (overs, unders))
check' (Lambda (_binder, _body) _clauses) (_overs, _unders) = error "Multi clause lambda doesn't check yet"
check' (Lambda c@(WC abstFC abst, body) cs) (overs, unders) = do
-- Used overs have their port pulling taken care of
(problem, rightOverSrcs) <- localFC abstFC $ argProblemsWithLeftovers (fst <$> overs) (normaliseAbstractor abst) []
-- That processes the whole abstractor, so all of the overs in the
-- `Problem` it creates are used
let usedOvers = [ (src, fromJust (lookup src overs)) | (src, _) <- problem ]
let rightOvers = [ over | over@(src,_) <- overs, src `elem` rightOverSrcs ]
case diry @d of
Chky -> do
-- We'll check the first variant against a Hypo node (omitted from compilation)
-- to work out how many overs/unders it needs, and then check it again (in Chk)
-- with the other clauses, as part of the body.
(ins :->> outs) <- mkSig usedOvers unders
(allFakeUnders, rightFakeUnders, tgtMap) <- suppressHoles $ suppressGraph $ do
(_, [], fakeOvers, fakeAcc) <- anext "lambda_fake_source" Hypo (S0, Some (Zy :* S0)) R0 ins
-- Hypo `check` calls need an environment, even just to compute leftovers;
-- we get that env by solving `problem` reformulated in terms of the `fakeOvers`
let srcMap = fromJust $ zip_same_length (fst <$> usedOvers) (fst <$> fakeOvers)
let fakeProblem = [ (fromJust (lookup src srcMap), pat) | (src, pat) <- problem ]
fakeEnv <- localFC abstFC $ solve ?my fakeProblem >>= (solToEnv . snd)
localEnv fakeEnv $ do
(_, fakeUnders, [], _) <- anext "lambda_fake_target" Hypo fakeAcc outs R0
Just tgtMap <- pure $ zip_same_length (fst <$> fakeUnders) unders
(((), ()), ((), rightFakeUnders)) <- check body ((), fakeUnders)
pure (fakeUnders, rightFakeUnders, tgtMap)

let usedFakeUnders = (fst <$> allFakeUnders) \\ (fst <$> rightFakeUnders)
let usedUnders = [ fromJust (lookup tgt tgtMap) | tgt <- usedFakeUnders ]
let rightUnders = [ fromJust (lookup tgt tgtMap) | (tgt, _) <- rightFakeUnders ]
sig <- mkSig usedOvers usedUnders
patOuts <- checkClauses sig usedOvers (c :| cs)
mkWires patOuts usedUnders
pure (((), ()), (rightOvers, rightUnders))
Syny -> do
synthOuts <- suppressHoles $ suppressGraph $ do
env <- localFC abstFC $
argProblems (fst <$> usedOvers) (normaliseAbstractor abst) [] >>=
solve ?my >>=
(solToEnv . snd)
(((), synthOuts), ((), ())) <- localEnv env $ check body ((), ())
pure synthOuts

sig <- mkSig usedOvers synthOuts
patOuts <- checkClauses sig usedOvers ((fst c, WC (fcOf body) (Emb body)) :| cs)
pure (((), patOuts), (rightOvers, ()))
where
-- Invariant: When solToEnv is called, port pulling has already been resolved,
-- because that's one of the functions of `argProblems`.
--
-- N.B.: Here we update the port names to be the user variable names for nicer
-- error messages. This mirrors previous behaviour using `abstract`, but is a
-- bit of a hack. See issue #23.
solToEnv :: [(String, (Src, BinderType m))] -> Checking (M.Map UserName (EnvData m))
solToEnv xs = traverse (uncurry singletonEnv) (portNamesToBoundNames xs) >>= mergeEnvs

portNamesToBoundNames :: [(String, (Src, BinderType m))] -> [(String, (Src, BinderType m))]
portNamesToBoundNames = fmap (\(n, (src, ty)) -> (n, (NamedPort (end src) n, ty)))

mkSig :: ToEnd t => [(Src, BinderType m)] -> [(NamedPort t, BinderType m)] -> Checking (CTy m Z)
mkSig overs unders = rowToRo ?my (retuple <$> overs) S0 >>=
\(Some (inRo :* endz)) -> rowToRo ?my (retuple <$> unders) endz >>=
\(Some (outRo :* _)) -> pure (inRo :->> outRo)

retuple (NamedPort e p, ty) = (p, e, ty)

mkWires overs unders = case zip_same_length overs unders of
Nothing -> err $ InternalError "Trying to wire up different sized lists of wires"
Just conns -> traverse (\((src, ty), (tgt, _)) -> wire (src, binderToValue ?my ty, tgt)) conns

checkClauses cty@(ins :->> outs) overs all_cs = do
let clauses = NE.zip (NE.fromList [0..]) all_cs <&>
\(i, (abs, tm)) -> Clause i (normaliseAbstractor <$> abs) tm
clauses <- traverse (checkClause ?my "lambda" cty) clauses
(_, patMatchUnders, patMatchOvers, _) <- anext "lambda" (PatternMatch clauses) (S0, Some (Zy :* S0))
ins
outs
mkWires overs patMatchUnders
pure patMatchOvers

check' (Pull ports t) (overs, unders) = do
unders <- pullPortsRow ports unders
check t (overs, unders)
Expand Down Expand Up @@ -419,6 +487,73 @@ check' (Simple tm) ((), ((hungry, ty):unders)) = do
pure (((), ()), ((), unders))
check' tm _ = error $ "check' " ++ show tm


-- Clauses from either function definitions or case statements, as we get
-- them from the elaborator
data Clause = Clause
{ index :: Int -- Which clause is this (in the order they're defined in source)
, lhs :: WC NormalisedAbstractor
, rhs :: WC (Term Chk Noun)
}
deriving Show

-- Return the tests that need to succeed for this clause to fire
-- (Tests are always defined on the overs of the outer box, rather than on
-- refined overs)
checkClause :: forall m. (CheckConstraints m UVerb, EvMode m) => Modey m
-> String
-> CTy m Z
-> Clause
-> Checking
( TestMatchData m -- TestMatch data (LHS)
, Name -- Function node (RHS)
)
checkClause my fnName cty clause = modily my $ do
let clauseName = fnName ++ "." ++ show (index clause)

-- First, we check the patterns on the LHS. This requires some overs,
-- so we make a box, however this box will be skipped during compilation.
(vars, match, rhsCty) <- suppressHoles . fmap snd $
let ?my = my in makeBox (clauseName ++ "_setup") cty $
\(overs, unders) -> do
-- Make a problem to solve based on the lhs and the overs
problem <- argProblems (fst <$> overs) (unWC $ lhs clause) []
(tests, sol) <- localFC (fcOf (lhs clause)) $ solve my problem
-- The solution gives us the variables bound by the patterns.
-- We turn them into a row
Some (patEz :* patRo) <- mkArgRo my S0 ((\(n, (src, ty)) -> (NamedPort (toEnd src) n, ty)) <$> sol)
-- Also make a row for the refined outputs (shifted by the pattern environment)
Some (_ :* outRo) <- mkArgRo my patEz (first (fmap toEnd) <$> unders)
let match = TestMatchData my $ MatchSequence overs tests (snd <$> sol)
let vars = fst <$> sol
pure (vars, match, patRo :->> outRo)

-- Now actually make a box for the RHS and check it
((boxPort, _ty), _) <- let ?my = my in makeBox (clauseName ++ "_rhs") rhsCty $ \(rhsOvers, rhsUnders) -> do
let abstractor = foldr ((:||:) . APat . Bind) AEmpty vars
let ?my = my in do
env <- abstractAll rhsOvers abstractor
localEnv env $ check @m (rhs clause) ((), rhsUnders)
let NamedPort {end=Ex rhsNode _} = boxPort
pure (match, rhsNode)

-- Top level function for type checking function definitions
-- Will make a top-level box for the function, then type check the definition
checkBody :: (CheckConstraints m UVerb, EvMode m, ?my :: Modey m)
=> String -- The function name
-> FunBody Term UVerb
-> CTy m Z -- Function type
-> Checking Src
checkBody fnName body cty = case body of
NoLhs tm -> do
((src, _), _) <- makeBox (fnName ++ ".box") cty $ \(overs, unders) -> check tm (overs, unders)
pure src
Clauses (c :| cs) -> do
fc <- req AskFC
((box, _), _) <- makeBox (fnName ++ ".box") cty (check (WC fc (Lambda c cs)))
pure box
Undefined -> err (InternalError "Checking undefined clause")

-- Constructs row from a list of ends and types. Uses standardize to ensure that dependency is
-- detected. Fills in the first bot ends from a stack. The stack grows every time we go under
-- a binder. The final stack is returned, so we can compute an output row after an input row.
Expand Down
2 changes: 1 addition & 1 deletion brat/Brat/Checker/Helpers.hs
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ pullPorts toPort showFn (p:ports) types = do
pull1Port :: PortName
-> [(a, ty)]
-> Checking ((a, ty), [(a, ty)])
pull1Port p [] = fail $ "Port not found: " ++ p
pull1Port p [] = fail $ "Port not found: " ++ p ++ " in " ++ showFn types
pull1Port p (x@(a,_):xs)
| p == toPort a
= if (p `elem` (toPort . fst <$> xs))
Expand Down
23 changes: 18 additions & 5 deletions brat/Brat/Checker/Monad.hs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module Brat.Checker.Monad where

import Brat.Checker.Quantity (Quantity(..), qpred)
import Brat.Checker.Quantity (Quantity(..))
import Brat.Checker.Types hiding (HoleData(..))
import Brat.Constructors (ConstructorMap, CtorArgs)
import Brat.Error (Error(..), ErrorMsg(..), dumbErr)
Expand Down Expand Up @@ -118,7 +118,7 @@ localVEnv ext (Req AskVEnv k) = do env <- req AskVEnv
localVEnv ext (k (env { locals = M.union ext (locals env) }))
localVEnv ext (Req (InLvl str c) k) = Req (InLvl str (localVEnv ext c)) (localVEnv ext . k)
localVEnv ext (Req r k) = Req r (localVEnv ext . k)

-- runs a computation, but intercepts uses of outer *locals* variables and redirects
-- them to use new outports of the specified node (expected to be a Source).
-- Returns a list of captured variables and their generated (Source-node) outports
Expand Down Expand Up @@ -184,9 +184,9 @@ lookupAndUse :: UserName -> KEnv
-> Either Error (Maybe ((Src, BinderType Kernel), KEnv))
lookupAndUse x kenv = case M.lookup x kenv of
Nothing -> Right Nothing
Just (q, rest) -> case qpred q of
Nothing -> Left $ dumbErr $ TypeErr $ (show x) ++ " has already been used"
Just q -> Right $ Just (rest, M.insert x (q, rest) kenv)
Just (None, _) -> Left $ dumbErr $ TypeErr $ (show x) ++ " has already been used"
Just (One, rest) -> Right $ Just (rest, M.insert x (None, rest) kenv)
Just (Tons, rest) -> Right $ Just (rest, M.insert x (Tons, rest) kenv)

localKVar :: KEnv -> Checking v -> Checking v
localKVar _ (Ret v) = Ret v
Expand Down Expand Up @@ -305,3 +305,16 @@ instance FreshMonad Checking where
-- This way we get file contexts when pattern matching fails
instance MonadFail Checking where
fail = typeErr

-- Run a computation without logging any holes
suppressHoles :: Checking a -> Checking a
suppressHoles (Ret x) = Ret x
suppressHoles (Req (LogHole _) k) = suppressHoles (k ())
suppressHoles (Req c k) = Req c (suppressHoles . k)

-- Run a computation without doing any graph generation
suppressGraph :: Checking a -> Checking a
suppressGraph (Ret x) = Ret x
suppressGraph (Req (AddNode _ _) k) = suppressGraph (k ())
suppressGraph (Req (Wire _) k) = suppressGraph (k ())
suppressGraph (Req c k) = Req c (suppressGraph . k)
7 changes: 1 addition & 6 deletions brat/Brat/Checker/Quantity.hs
Original file line number Diff line number Diff line change
@@ -1,8 +1,3 @@
module Brat.Checker.Quantity where

data Quantity = None | One | Tons deriving Show

qpred :: Quantity -> Maybe Quantity
qpred None = Nothing
qpred One = Just None
qpred Tons = Just Tons
data Quantity = None | One | Tons deriving (Enum, Show)
Loading

0 comments on commit ed58724

Please sign in to comment.