Skip to content

Commit

Permalink
working on case expressions
Browse files Browse the repository at this point in the history
  • Loading branch information
gnumonik committed Mar 26, 2024
1 parent 3b24eef commit 3f8183d
Showing 1 changed file with 97 additions and 18 deletions.
115 changes: 97 additions & 18 deletions src/Language/PureScript/CoreFn/Convert/ToPIR.hs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
module Language.PureScript.CoreFn.Convert.ToPIR where

import Prelude
import Language.PureScript.Names (Qualified (..), ProperName(..), runIdent, pattern ByThisModuleName)
import Language.PureScript.Names (Qualified (..), ProperName(..), runIdent, pattern ByThisModuleName, disqualify)
import Language.PureScript.CoreFn.FromJSON ()
import Data.Text qualified as T
import Language.PureScript.PSString (prettyPrintString)
Expand Down Expand Up @@ -46,7 +46,8 @@ import PlutusIR.Error
import Control.Monad.Reader
import PlutusCore qualified as PLC
import Control.Exception

import Data.List (sortOn)
import Control.Lens ((&),(.~),ix)
type PLCProgram uni fun a = PLC.Program PLC.TyName PLC.Name uni fun (Provenance a)

fuckThisMonadStack ::
Expand Down Expand Up @@ -137,6 +138,9 @@ genFreshName = do
i <- getCounter
pure $ Name ("~" <> T.pack (show i)) (Unique i)

genFreshTyName :: State ConvertState TyName
genFreshTyName = TyName <$> genFreshName

-- N.B. We use the counter for type names (b/c we don't have Bound ADT)
data ConvertState = ConvertState {
counter :: Int,
Expand Down Expand Up @@ -200,6 +204,7 @@ toPIRType = \case
C.String -> TyBuiltin () (SomeTypeIn DefaultUniString)
C.Char -> TyBuiltin () (SomeTypeIn DefaultUniInteger)
C.Int -> TyBuiltin () (SomeTypeIn DefaultUniInteger)
C.Boolean -> TyBuiltin () (SomeTypeIn DefaultUniBool)
other -> error $ "unsupported prim tycon: " <> show other


Expand All @@ -217,9 +222,11 @@ mkKind (Just t) = foldr1 (PIR.KindArrow ()) (collect t)
-- TODO: Real monad stack w/ errors and shit
toPIR :: forall x. (x -> Var BVar FVar) -> Exp x -> State ConvertState PIRTerm
toPIR f = \case
V (f -> F (FVar _ ident)) -> case M.lookup (showIdent' ident) defaultFunMap of
Just aBuiltin -> pure $ Builtin () aBuiltin
Nothing -> error $ showIdent' ident <> " isn't a builtin, and it shouldn't be possible to have a free variable that's anything but a builtin"
V (f -> F (FVar _ ident)) -> do
let nm = showIdent' ident
case M.lookup (showIdent' ident) defaultFunMap of
Just aBuiltin -> pure $ Builtin () aBuiltin
Nothing -> do error $ showIdent' ident <> " isn't a builtin, and it shouldn't be possible to have a free variable that's anything but a builtin"
V (f -> B (BVar n t i)) -> PIR.Var () <$> mkTermName (runIdent i)
LitE _ lit -> litToPIR f lit
CtorE ty _ cn _ -> gets ctorDict >>= \dict -> do
Expand Down Expand Up @@ -248,10 +255,8 @@ toPIR f = \case
- c) Apply a to b. Hopefully!
-}
CaseE ty [scrutinee] alts -> do
scrutineeCtor <- locally $ assembleScrutinee scrutinee alts
branchEliminators <- locally $ assembleBranches alts
ty' <- toPIRType ty
pure $ PIR.Case () ty' scrutineeCtor branchEliminators
scrutinee' <- toPIR f scrutinee
assembleScrutinee scrutinee' ty alts
-- TODO: I'm just going to mark all binding groups as recursive for now and do
-- the sophisticated thing later. so tired.
LetE ty cxtmap binds exp -> do
Expand All @@ -273,16 +278,90 @@ toPIR f = \case
exp' <- toPIR (>>= f) $ instantiateEither (either (IR.V . B) (IR.V . F)) scopedExp
pure $ TermBind () Strict (VarDecl () nm ty') exp'

assembleBranches :: [Alt Exp x] -> State ConvertState [Term TyName Name DefaultUni DefaultFun ()]
assembleBranches = error "Unimplemented: AssembleBranches"
assembleScrutinee :: PIRTerm -> Ty -> [Alt Exp x] -> State ConvertState (Term TyName Name DefaultUni DefaultFun ())
assembleScrutinee scrut tx alts = do
let _binders = IR.getPat <$> alts
-- TODO: remove when implementing multi scrutinee
binders = map head _binders
alted <- unzip <$> traverse (locally . goAlt tx) alts
let sopSchema = head . fst $ alted
ctorNumberedBranches = snd alted
failBranches <- traverse mkFailBranch sopSchema
tx' <- toPIRType tx
let resolvedBranches = foldr (\(i,x) acc -> acc & ix i .~ x) failBranches ctorNumberedBranches
pure $ Case () tx' scrut resolvedBranches
where
boolT = TyBuiltin () (SomeTypeIn DefaultUniBool)
mkFailBranch :: [Ty] -> State ConvertState PIRTerm
mkFailBranch [] = pure $ mkConstant () False
mkFailBranch [t] = do
lamName <- genFreshName
t' <- toPIRType t
pure $ PIR.LamAbs () lamName (PIR.TyFun () t' boolT) (mkConstant () False)
mkFailBranch (t:ts) = do
lamName <- genFreshName
t' <- toPIRType t
rest <- mkFailBranch ts
pure $ PIR.LamAbs () lamName (PIR.TyFun () t' boolT) rest

goAlt :: Ty -> Alt Exp x -> State ConvertState ([[Ty]],(Int,PIRTerm))
goAlt t (UnguardedAlt _ [pat] body) = do
body' <- toPIR (>>= f) $ instantiateEither (either (IR.V . B) (IR.V . F)) body
patToBoolFunc body' t pat
where

patToBoolFunc :: PIRTerm -> Ty -> Pat Exp x -> State ConvertState ([[Ty]],(Int,PIRTerm))
patToBoolFunc res tx = \case
ConP tn cn ips -> do
ctordict <- gets ctorDict
tcdict <- gets tyConDict
tx' <- toPIRType tx
let cn' = disqualify cn
tn' = disqualify tn
case (M.lookup cn' ctordict, M.lookup tn' tcdict) of
(Just (_,ctorix,tys), Just (_,_,ctors)) -> do
let ctorTypes = map snd . sortOn fst $ ctors
ibfs <- goCtorArgs (zip tys ips)
pure (ctorTypes,(ctorix,ibfs))
where
goCtorArgs :: [(Ty,Pat Exp x)] -> State ConvertState PIRTerm
goCtorArgs [] = pure res
goCtorArgs [(t,VarP nm)] = do
nm' <- mkTermName (runIdent nm)
t' <- toPIRType t
pure $ LamAbs () nm' t' undefined
goCtorArgs ((t,VarP nm):rest) = do
nm' <- mkTermName (runIdent nm)
t' <- toPIRType t
rest' <- goCtorArgs rest
pure $ LamAbs () nm' t' rest'






-- NOTE: We don't have force/delay in PIR so I think we have to use type abstraction/instantiation
-- force ((\cond -> IfThenElse cond (delay caseT) (delay caseF)) cond)
-- TyInst _
-- This is probably wrong and we need quantifiers (see https://github.com/IntersectMBO/plutus/blob/381172295c0b0a8f17450b8377ee5905f03d294b/plutus-core/plutus-ir/src/PlutusIR/Transform/NonStrict.hs#L82-L95)
-- TyInst ann (Var ann name) (TyForall ann a (Type ann) (TyVar ann a))
-- TermBind x Strict (VarDecl x' name (TyForall ann a (Type ann) ty)) (TyAbs ann a (Type ann) rhs)
pIfTe :: PIRTerm -> PIRTerm -> PIRTerm -> State ConvertState PIRTerm
pIfTe cond troo fawlse = do
scrutineeNm <- genFreshName
let scrutineeTNm = TyName scrutineeNm
let boolT = TyBuiltin () (SomeTypeIn DefaultUniBool)
builtinIfTe = Builtin () IfThenElse
ifte c t f = PIR.Apply () (PIR.Apply () (PIR.Apply () builtinIfTe c) t) f
sVar = (PIR.Var () scrutineeNm)
tBranch = (PIR.TyAbs () scrutineeTNm (PIR.Type ()) troo)
fBranch = (PIR.TyAbs () scrutineeTNm (PIR.Type ()) fawlse)
body = PIR.LamAbs () scrutineeNm boolT
$ ifte sVar tBranch fBranch
forced = TyInst () body (TyForall () scrutineeTNm (PIR.Type ()) (PIR.TyVar () scrutineeTNm))
pure $ PIR.Apply () forced cond

assembleScrutinee :: Exp x -> [Alt Exp x] -> State ConvertState (Term TyName Name DefaultUni DefaultFun ())
assembleScrutinee ex alts = do
let _binders = IR.getPat <$> alts
-- TODO: remove when implementing multi scrutinee
binders = map head _binders
-- we turn binders into lambdas that return the variables bound in the alts in a ctor, indexed by the position of the alt
error "Unimplemented: AssembleScrutinee"

argTy :: Ty -> Ty
argTy (a :~> b) = a
Expand Down

0 comments on commit 3f8183d

Please sign in to comment.