diff --git a/lambda-buffers-compiler/src/LambdaBuffers/Compiler/KindCheck.hs b/lambda-buffers-compiler/src/LambdaBuffers/Compiler/KindCheck.hs index a2fff947..506277ef 100644 --- a/lambda-buffers-compiler/src/LambdaBuffers/Compiler/KindCheck.hs +++ b/lambda-buffers-compiler/src/LambdaBuffers/Compiler/KindCheck.hs @@ -19,11 +19,12 @@ import Control.Monad.Freer.TH (makeEffect) import Data.Default (Default (def)) import Data.Foldable (Foldable (toList), traverse_) import Data.Map qualified as M -import LambdaBuffers.Compiler.KindCheck.Derivation (Context, constraintContext, context) +import LambdaBuffers.Compiler.KindCheck.Derivation (Context, classContext, context) import LambdaBuffers.Compiler.KindCheck.Inference (protoKind2Kind) import LambdaBuffers.Compiler.KindCheck.Inference qualified as I import LambdaBuffers.Compiler.KindCheck.Kind (Kind (KConstraint, KType, KVar, (:->:))) import LambdaBuffers.Compiler.KindCheck.Type ( + QualifiedTyClassRefName (QualifiedTyClassRefName), Variable (QualifiedTyClassRef, QualifiedTyRef, TyVar), fcrISOqtcr, ftrISOqtr, @@ -302,11 +303,13 @@ kind2ProtoKind = \case -------------------------------------------------------------------------------- -- Class Definition Based Context Building. -classDef2Context :: forall effs. PC.ClassDef -> Eff effs Context +classDef2Context :: forall effs. Member (Reader PC.ModuleName) effs => PC.ClassDef -> Eff effs Context classDef2Context cDef = do - let className = mkInfoLess . view #className $ cDef + modName <- ask + let className = cDef ^. #className + let qtcn = mkInfoLess . QualifiedTyClassRef $ QualifiedTyClassRefName className modName def let classArg = tyArg2Kind . view #classArgs $ cDef - pure $ mempty & constraintContext .~ M.singleton className classArg + pure $ mempty & classContext .~ M.singleton qtcn (classArg :->: KConstraint) -------------------------------------------------------------------------------- -- utilities diff --git a/lambda-buffers-compiler/src/LambdaBuffers/Compiler/KindCheck/Derivation.hs b/lambda-buffers-compiler/src/LambdaBuffers/Compiler/KindCheck/Derivation.hs index f496ab43..b86a27e4 100644 --- a/lambda-buffers-compiler/src/LambdaBuffers/Compiler/KindCheck/Derivation.hs +++ b/lambda-buffers-compiler/src/LambdaBuffers/Compiler/KindCheck/Derivation.hs @@ -1,3 +1,5 @@ +{-# OPTIONS_GHC -Wno-unused-imports #-} + module LambdaBuffers.Compiler.KindCheck.Derivation ( Derivation (Axiom, Abstraction, Application, Implication), QClassName (QClassName), @@ -11,15 +13,14 @@ module LambdaBuffers.Compiler.KindCheck.Derivation ( context, addContext, getAllContext, - constraintContext, + classContext, ) where import Control.Lens (Lens', lens, makeLenses, (&), (.~), (^.)) import Control.Lens.Operators ((%~)) import Data.Map qualified as M import LambdaBuffers.Compiler.KindCheck.Kind (Kind) -import LambdaBuffers.Compiler.KindCheck.Type (Type, Variable) -import LambdaBuffers.Compiler.ProtoCompat qualified as PC +import LambdaBuffers.Compiler.KindCheck.Type (Type, Variable (QualifiedTyClassRef)) import LambdaBuffers.Compiler.ProtoCompat.InfoLess (InfoLess) import Prettyprinter ( Doc, @@ -43,7 +44,7 @@ data QClassName = QClassName data Context = Context { _context :: M.Map (InfoLess Variable) Kind , _addContext :: M.Map (InfoLess Variable) Kind - , _constraintContext :: M.Map (InfoLess PC.ClassName) Kind + , _classContext :: M.Map (InfoLess Variable) Kind } deriving stock (Show, Eq) @@ -62,14 +63,14 @@ instance Semigroup Context where c1 & context %~ (<> c2 ^. context) & addContext %~ (<> c2 ^. addContext) - & constraintContext %~ (<> c2 ^. constraintContext) + & classContext %~ (<> c2 ^. classContext) instance Monoid Context where mempty = Context mempty mempty mempty -- | Utility to unify the two T. getAllContext :: Context -> M.Map (InfoLess Variable) Kind -getAllContext c = c ^. context <> c ^. addContext +getAllContext c = c ^. context <> c ^. addContext <> c ^. classContext data Judgement = Judgement { _j'ctx :: Context diff --git a/lambda-buffers-compiler/src/LambdaBuffers/Compiler/KindCheck/Inference.hs b/lambda-buffers-compiler/src/LambdaBuffers/Compiler/KindCheck/Inference.hs index d07954f6..4c46219c 100644 --- a/lambda-buffers-compiler/src/LambdaBuffers/Compiler/KindCheck/Inference.hs +++ b/lambda-buffers-compiler/src/LambdaBuffers/Compiler/KindCheck/Inference.hs @@ -21,6 +21,7 @@ module LambdaBuffers.Compiler.KindCheck.Inference ( ) where import Control.Lens ((%~), (&), (.~), (^.)) +import Control.Lens.Combinators (to) import Control.Lens.Iso (withIso) import Control.Monad (void) import Control.Monad.Freer (Eff, Member, Members, run) @@ -47,7 +48,6 @@ import LambdaBuffers.Compiler.KindCheck.Type ( Variable (QualifiedTyClassRef, QualifiedTyRef, TyVar), fcrISOqtcr, ftrISOqtr, - lcrISOftcr, lcrISOqtcr, ltrISOqtr, ) @@ -250,8 +250,17 @@ runClassDefCheck ctx modName classDef = do -- | Checks the class definition for correct typedness. deriveClassDef :: PC.ClassDef -> Derive () -deriveClassDef classDef = - traverse_ deriveConstraint (classDef ^. #supers) +deriveClassDef classDef = do + vars <- createLocalConstraintContext classDef + traverse_ (local (<> vars) . deriveConstraint) (classDef ^. #supers) + +-- | Adds the kind of the variable to the local context. +createLocalConstraintContext :: PC.ClassDef -> Derive Context +createLocalConstraintContext cd = do + let arg = cd ^. #classArgs + let n = mkInfoLess $ TyVar $ arg ^. #argName + let k = arg ^. #argKind . to protoKind2Kind + return $ mempty & addContext .~ M.singleton n k deriveConstraint :: PC.Constraint -> Derive Derivation deriveConstraint constraint = do @@ -397,5 +406,5 @@ protoKind2Kind = \case PC.Kind k -> case k of PC.KindArrow k1 k2 -> protoKind2Kind k1 :->: protoKind2Kind k2 PC.KindRef PC.KType -> KType - PC.KindRef PC.KUnspecified -> KType PC.KindRef PC.KConstraint -> KConstraint + PC.KindRef PC.KUnspecified -> KType diff --git a/lambda-buffers-compiler/test/Test/Utils/Module.hs b/lambda-buffers-compiler/test/Test/Utils/Module.hs index dca4a077..1021c5f6 100644 --- a/lambda-buffers-compiler/test/Test/Utils/Module.hs +++ b/lambda-buffers-compiler/test/Test/Utils/Module.hs @@ -82,4 +82,4 @@ module'classEq = _Module (_ModuleName ["Module"]) mempty [classDef'Eq] mempty class Eq a => Ord a. -} module'classOrd :: PC.Module -module'classOrd = _Module (_ModuleName ["Module"]) mempty [classDef'Ord] mempty +module'classOrd = _Module (_ModuleName ["Module"]) mempty [classDef'Eq, classDef'Ord] mempty