diff --git a/src/Language/PureScript/CoreFn/Convert/ToPIR.hs b/src/Language/PureScript/CoreFn/Convert/ToPIR.hs index 323c06c9..751b2a61 100644 --- a/src/Language/PureScript/CoreFn/Convert/ToPIR.hs +++ b/src/Language/PureScript/CoreFn/Convert/ToPIR.hs @@ -10,7 +10,7 @@ module Language.PureScript.CoreFn.Convert.ToPIR where import Prelude -import Language.PureScript.Names (Qualified (..), ProperName(..), runIdent, pattern ByThisModuleName) +import Language.PureScript.Names (Qualified (..), ProperName(..), runIdent, pattern ByThisModuleName, disqualify) import Language.PureScript.CoreFn.FromJSON () import Data.Text qualified as T import Language.PureScript.PSString (prettyPrintString) @@ -46,7 +46,8 @@ import PlutusIR.Error import Control.Monad.Reader import PlutusCore qualified as PLC import Control.Exception - +import Data.List (sortOn) +import Control.Lens ((&),(.~),ix) type PLCProgram uni fun a = PLC.Program PLC.TyName PLC.Name uni fun (Provenance a) fuckThisMonadStack :: @@ -137,6 +138,9 @@ genFreshName = do i <- getCounter pure $ Name ("~" <> T.pack (show i)) (Unique i) +genFreshTyName :: State ConvertState TyName +genFreshTyName = TyName <$> genFreshName + -- N.B. We use the counter for type names (b/c we don't have Bound ADT) data ConvertState = ConvertState { counter :: Int, @@ -200,6 +204,7 @@ toPIRType = \case C.String -> TyBuiltin () (SomeTypeIn DefaultUniString) C.Char -> TyBuiltin () (SomeTypeIn DefaultUniInteger) C.Int -> TyBuiltin () (SomeTypeIn DefaultUniInteger) + C.Boolean -> TyBuiltin () (SomeTypeIn DefaultUniBool) other -> error $ "unsupported prim tycon: " <> show other @@ -217,9 +222,11 @@ mkKind (Just t) = foldr1 (PIR.KindArrow ()) (collect t) -- TODO: Real monad stack w/ errors and shit toPIR :: forall x. (x -> Var BVar FVar) -> Exp x -> State ConvertState PIRTerm toPIR f = \case - V (f -> F (FVar _ ident)) -> case M.lookup (showIdent' ident) defaultFunMap of - Just aBuiltin -> pure $ Builtin () aBuiltin - Nothing -> error $ showIdent' ident <> " isn't a builtin, and it shouldn't be possible to have a free variable that's anything but a builtin" + V (f -> F (FVar _ ident)) -> do + let nm = showIdent' ident + case M.lookup (showIdent' ident) defaultFunMap of + Just aBuiltin -> pure $ Builtin () aBuiltin + Nothing -> do error $ showIdent' ident <> " isn't a builtin, and it shouldn't be possible to have a free variable that's anything but a builtin" V (f -> B (BVar n t i)) -> PIR.Var () <$> mkTermName (runIdent i) LitE _ lit -> litToPIR f lit CtorE ty _ cn _ -> gets ctorDict >>= \dict -> do @@ -248,10 +255,8 @@ toPIR f = \case - c) Apply a to b. Hopefully! -} CaseE ty [scrutinee] alts -> do - scrutineeCtor <- locally $ assembleScrutinee scrutinee alts - branchEliminators <- locally $ assembleBranches alts - ty' <- toPIRType ty - pure $ PIR.Case () ty' scrutineeCtor branchEliminators + scrutinee' <- toPIR f scrutinee + assembleScrutinee scrutinee' ty alts -- TODO: I'm just going to mark all binding groups as recursive for now and do -- the sophisticated thing later. so tired. LetE ty cxtmap binds exp -> do @@ -273,16 +278,90 @@ toPIR f = \case exp' <- toPIR (>>= f) $ instantiateEither (either (IR.V . B) (IR.V . F)) scopedExp pure $ TermBind () Strict (VarDecl () nm ty') exp' -assembleBranches :: [Alt Exp x] -> State ConvertState [Term TyName Name DefaultUni DefaultFun ()] -assembleBranches = error "Unimplemented: AssembleBranches" + assembleScrutinee :: PIRTerm -> Ty -> [Alt Exp x] -> State ConvertState (Term TyName Name DefaultUni DefaultFun ()) + assembleScrutinee scrut tx alts = do + let _binders = IR.getPat <$> alts + -- TODO: remove when implementing multi scrutinee + binders = map head _binders + alted <- unzip <$> traverse (locally . goAlt tx) alts + let sopSchema = head . fst $ alted + ctorNumberedBranches = snd alted + failBranches <- traverse mkFailBranch sopSchema + tx' <- toPIRType tx + let resolvedBranches = foldr (\(i,x) acc -> acc & ix i .~ x) failBranches ctorNumberedBranches + pure $ Case () tx' scrut resolvedBranches + where + boolT = TyBuiltin () (SomeTypeIn DefaultUniBool) + mkFailBranch :: [Ty] -> State ConvertState PIRTerm + mkFailBranch [] = pure $ mkConstant () False + mkFailBranch [t] = do + lamName <- genFreshName + t' <- toPIRType t + pure $ PIR.LamAbs () lamName (PIR.TyFun () t' boolT) (mkConstant () False) + mkFailBranch (t:ts) = do + lamName <- genFreshName + t' <- toPIRType t + rest <- mkFailBranch ts + pure $ PIR.LamAbs () lamName (PIR.TyFun () t' boolT) rest + + goAlt :: Ty -> Alt Exp x -> State ConvertState ([[Ty]],(Int,PIRTerm)) + goAlt t (UnguardedAlt _ [pat] body) = do + body' <- toPIR (>>= f) $ instantiateEither (either (IR.V . B) (IR.V . F)) body + patToBoolFunc body' t pat + where + + patToBoolFunc :: PIRTerm -> Ty -> Pat Exp x -> State ConvertState ([[Ty]],(Int,PIRTerm)) + patToBoolFunc res tx = \case + ConP tn cn ips -> do + ctordict <- gets ctorDict + tcdict <- gets tyConDict + tx' <- toPIRType tx + let cn' = disqualify cn + tn' = disqualify tn + case (M.lookup cn' ctordict, M.lookup tn' tcdict) of + (Just (_,ctorix,tys), Just (_,_,ctors)) -> do + let ctorTypes = map snd . sortOn fst $ ctors + ibfs <- goCtorArgs (zip tys ips) + pure (ctorTypes,(ctorix,ibfs)) + where + goCtorArgs :: [(Ty,Pat Exp x)] -> State ConvertState PIRTerm + goCtorArgs [] = pure res + goCtorArgs [(t,VarP nm)] = do + nm' <- mkTermName (runIdent nm) + t' <- toPIRType t + pure $ LamAbs () nm' t' undefined + goCtorArgs ((t,VarP nm):rest) = do + nm' <- mkTermName (runIdent nm) + t' <- toPIRType t + rest' <- goCtorArgs rest + pure $ LamAbs () nm' t' rest' + + + + + + + -- NOTE: We don't have force/delay in PIR so I think we have to use type abstraction/instantiation + -- force ((\cond -> IfThenElse cond (delay caseT) (delay caseF)) cond) + -- TyInst _ + -- This is probably wrong and we need quantifiers (see https://github.com/IntersectMBO/plutus/blob/381172295c0b0a8f17450b8377ee5905f03d294b/plutus-core/plutus-ir/src/PlutusIR/Transform/NonStrict.hs#L82-L95) + -- TyInst ann (Var ann name) (TyForall ann a (Type ann) (TyVar ann a)) + -- TermBind x Strict (VarDecl x' name (TyForall ann a (Type ann) ty)) (TyAbs ann a (Type ann) rhs) + pIfTe :: PIRTerm -> PIRTerm -> PIRTerm -> State ConvertState PIRTerm + pIfTe cond troo fawlse = do + scrutineeNm <- genFreshName + let scrutineeTNm = TyName scrutineeNm + let boolT = TyBuiltin () (SomeTypeIn DefaultUniBool) + builtinIfTe = Builtin () IfThenElse + ifte c t f = PIR.Apply () (PIR.Apply () (PIR.Apply () builtinIfTe c) t) f + sVar = (PIR.Var () scrutineeNm) + tBranch = (PIR.TyAbs () scrutineeTNm (PIR.Type ()) troo) + fBranch = (PIR.TyAbs () scrutineeTNm (PIR.Type ()) fawlse) + body = PIR.LamAbs () scrutineeNm boolT + $ ifte sVar tBranch fBranch + forced = TyInst () body (TyForall () scrutineeTNm (PIR.Type ()) (PIR.TyVar () scrutineeTNm)) + pure $ PIR.Apply () forced cond -assembleScrutinee :: Exp x -> [Alt Exp x] -> State ConvertState (Term TyName Name DefaultUni DefaultFun ()) -assembleScrutinee ex alts = do - let _binders = IR.getPat <$> alts - -- TODO: remove when implementing multi scrutinee - binders = map head _binders - -- we turn binders into lambdas that return the variables bound in the alts in a ctor, indexed by the position of the alt - error "Unimplemented: AssembleScrutinee" argTy :: Ty -> Ty argTy (a :~> b) = a