Skip to content

Commit

Permalink
Minor optimisations to unification (#859)
Browse files Browse the repository at this point in the history
* Minor optimisations to unification

* Add CHANGELOG
  • Loading branch information
MatthewDaggitt authored Nov 7, 2024
1 parent a430570 commit 1c0c9d5
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 84 deletions.
2 changes: 2 additions & 0 deletions ChangeLog.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

## Version 0.16

* Decreased type-checking time by ~50%

* Decreased the size of generated verification plan files by 75%

* Improved the ordering of constraints in generated query files.
Expand Down
12 changes: 1 addition & 11 deletions vehicle/src/Vehicle/Compile/Type/Constraint/Core.hs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ module Vehicle.Compile.Type.Constraint.Core
extractHeadFromInstanceCandidate,
findInstanceGoalHead,
parseInstanceGoal,
unify,
createInstanceUnification,
createSubInstance,
mkCandidate,
Expand Down Expand Up @@ -71,15 +70,6 @@ malformedConstraintError :: (PrintableBuiltin builtin, MonadCompile m) => WithCo
malformedConstraintError c =
compilerDeveloperError $ "Malformed type-class constraint:" <+> prettyVerbose c

-- | Create a new unification constraint, copying the context as appropriate.
unify ::
(MonadTypeChecker builtin m) =>
(ConstraintContext builtin, UnificationConstraintOrigin builtin) ->
(Value builtin, Value builtin) ->
m (WithContext (UnificationConstraint builtin))
unify (ctx, origin) (e1, e2) =
WithContext (Unify origin e1 e2) <$> copyContext ctx

-- | Create a new unification constraint as a subgoal of an existing instance constraint.
createInstanceUnification ::
(MonadTypeChecker builtin m) =>
Expand All @@ -89,7 +79,7 @@ createInstanceUnification ::
m (WithContext (Constraint builtin))
createInstanceUnification (ctx, origin) e1 e2 = do
let unifyOrigin = CheckingInstanceType origin
constraint <- unify (ctx, unifyOrigin) (e1, e2)
constraint <- WithContext (Unify unifyOrigin e1 e2) <$> copyContext ctx
return $ mapObject UnificationConstraint constraint

-- | Creates an instance constraint as a subgoal of an existing instance constraint.
Expand Down
182 changes: 109 additions & 73 deletions vehicle/src/Vehicle/Compile/Type/Constraint/UnificationSolver.hs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import Data.IntMap (IntMap)
import Data.IntMap qualified as IntMap
import Data.IntSet qualified as IntSet
import Data.List (intersect)
import Data.Maybe (mapMaybe)
import Data.Maybe (fromMaybe, mapMaybe)
import Data.Proxy (Proxy (..))
import Prettyprinter (sep)
import Vehicle.Compile.Error
Expand All @@ -19,7 +19,7 @@ import Vehicle.Compile.Normalise.Quote (Quote (..), unnormalise)
import Vehicle.Compile.Prelude
import Vehicle.Compile.Print (prettyFriendly, prettyVerbose)
import Vehicle.Compile.Type.Builtin (TypableBuiltin (..))
import Vehicle.Compile.Type.Constraint.Core (runConstraintSolver, unify)
import Vehicle.Compile.Type.Constraint.Core (runConstraintSolver)
import Vehicle.Compile.Type.Core
import Vehicle.Compile.Type.Force (forceHead)
import Vehicle.Compile.Type.Meta
Expand Down Expand Up @@ -58,57 +58,89 @@ solveUnificationConstraint ::
(MonadUnify builtin m) =>
WithContext (UnificationConstraint builtin) ->
m ()
solveUnificationConstraint (WithContext (Unify origin' e1 e2) ctx) = do
solveUnificationConstraint constraint = do
result <- solve constraint
case result of
Success -> return ()
Blocked blockedConstraints ->
addUnificationConstraints blockedConstraints
HardFailure ->
throwError $ TypingError $ FailedUnificationConstraints [constraint]

solve ::
forall builtin m.
(MonadUnify builtin m) =>
WithContext (UnificationConstraint builtin) ->
m (UnificationResult builtin)
solve (WithContext (Unify origin e1 e2) ctx) = do
-- Force the heads of both expressions
metaSubst <- getMetaSubstitution (Proxy @builtin)
(ne1', e1BlockingMetas) <- forceHead metaSubst ctx e1
(ne2', e2BlockingMetas) <- forceHead metaSubst ctx e2
(ne1, e1BlockingMetas) <- forceHead metaSubst ctx e1
(ne2, e2BlockingMetas) <- forceHead metaSubst ctx e2

-- In theory this substitution shouldn't be needed, but in practice it is as if
-- not all the meta-variables are substituted through then the scope of some
-- meta-variables may be larger than the current scope of the constraint.
-- These dependencies only disappear on substitution. Need to work out how to
-- avoid doing this.
nu@(Unify origin ne1 ne2) <- substMetas (Unify origin' ne1' ne2')
-- Construct the new constraint information
let blockingMetas = e1BlockingMetas <> e2BlockingMetas
let updatedConstraint = WithContext (Unify origin ne1 ne2) ctx
let constraintInfo = (updatedConstraint, blockingMetas)

result <- unification (ctx, origin) (e1BlockingMetas <> e2BlockingMetas) (ne1, ne2)
case result of
Success newConstraints -> do
addUnificationConstraints newConstraints
SoftFailure blockingMetas -> do
let normConstraint = WithContext nu ctx
let blockedConstraint = blockConstraintOn normConstraint blockingMetas
addUnificationConstraints [blockedConstraint]
HardFailure ->
throwError $ TypingError $ FailedUnificationConstraints [WithContext nu ctx]
-- Perform the unification
unification constraintInfo (ne1, ne2)

data UnificationResult builtin
= Success [WithContext (UnificationConstraint builtin)]
= Success
| -- | Always an error
HardFailure
| -- | Only an error when further reduction will never occur.
SoftFailure MetaSet
Blocked [WithContext (UnificationConstraint builtin)]

instance Semigroup (UnificationResult builtin) where
HardFailure <> _ = HardFailure
_ <> HardFailure = HardFailure
SoftFailure m1 <> SoftFailure m2 = SoftFailure (m1 <> m2)
r1@SoftFailure {} <> _ = r1
_ <> r2@SoftFailure {} = r2
Success cs1 <> Success cs2 = Success (cs1 <> cs2)
Blocked m1 <> Blocked m2 = Blocked (m1 <> m2)
r1@Blocked {} <> _ = r1
_ <> r2@Blocked {} = r2
Success <> Success = Success

instance Monoid (UnificationResult builtin) where
mempty = Success mempty
mempty = Success

type ConstraintInfo builtin = (WithContext (UnificationConstraint builtin), MetaSet)

ctxOf :: ConstraintInfo builtin -> ConstraintContext builtin
ctxOf (WithContext _ ctx, _) = ctx

-- | Create a new unification constraint, copying the context as appropriate.
subUnify ::
(MonadTypeChecker builtin m) =>
ConstraintInfo builtin ->
(Value builtin, Value builtin) ->
m (UnificationResult builtin)
subUnify (WithContext (Unify origin _ _) ctx, _) (e1, e2) =
solve . WithContext (Unify origin e1 e2) =<< copyContext ctx

block ::
(MonadUnify builtin m) =>
ConstraintInfo builtin ->
Maybe MetaSet ->
m (UnificationResult builtin)
block (WithContext constraint ctx, originalBlockingMetas) maybeRefinedBlockingMetas = do
let blockingMetas = fromMaybe originalBlockingMetas maybeRefinedBlockingMetas
if MetaSet.null blockingMetas
then return HardFailure
else do
newConstraint <- WithContext constraint <$> copyContext ctx
let blockedConstraint = blockConstraintOn newConstraint blockingMetas
return $ Blocked [blockedConstraint]

pattern (:~:) :: a -> b -> (a, b)
pattern x :~: y = (x, y)

unification ::
(MonadUnify builtin m) =>
(ConstraintContext builtin, UnificationConstraintOrigin builtin) ->
MetaSet ->
ConstraintInfo builtin ->
(Value builtin, Value builtin) ->
m (UnificationResult builtin)
unification info@(ctx, _) blockingMetas = \case
unification info = \case
-----------------------
-- Rigid-rigid cases --
-----------------------
Expand All @@ -132,60 +164,53 @@ unification info@(ctx, _) blockingMetas = \case
| meta1 == meta2 -> solveSpine info spine1 spine2
-- The longer spine normally means its in a deeper scope. This minor
-- optimisation tries to solve the deeper meta first.
| length spine1 < length spine2 -> solveFlexFlex ctx (meta2, spine2) (meta1, spine1)
| otherwise -> solveFlexFlex ctx (meta1, spine1) (meta2, spine2)
| length spine1 < length spine2 -> solveFlexFlex info (meta2, spine2) (meta1, spine1)
| otherwise -> solveFlexFlex info (meta1, spine1) (meta2, spine2)
----------------------
-- Flex-rigid cases --
----------------------
VMeta meta spine :~: e -> solveFlexRigid ctx (meta, spine) e
e :~: VMeta meta spine -> solveFlexRigid ctx (meta, spine) e
VMeta meta spine :~: e -> solveFlexRigid info (meta, spine) e
e :~: VMeta meta spine -> solveFlexRigid info (meta, spine) e
------------------
-- Blocked case --
------------------
_ ->
return $
if MetaSet.null blockingMetas
then HardFailure
else SoftFailure blockingMetas
_ -> block info Nothing

solveTrivially :: (MonadUnify builtin m) => m (UnificationResult builtin)
solveTrivially = do
logDebug MaxDetail "solved-trivially"
return $ Success mempty
return Success

solveArg ::
(MonadUnify builtin m) =>
(ConstraintContext builtin, UnificationConstraintOrigin builtin) ->
ConstraintInfo builtin ->
(VArg builtin, VArg builtin) ->
Maybe (m (UnificationResult builtin))
m (UnificationResult builtin)
solveArg info (arg1, arg2)
| not (visibilityMatches arg1 arg2) = Just $ return HardFailure
| isInstance arg1 = Nothing
| otherwise = Just $ do
argEq <- unify info (argExpr arg1, argExpr arg2)
return $ Success [argEq]
| not (visibilityMatches arg1 arg2) = return HardFailure
-- Don't unify instances, they should be uniquely determined by the type.
| isInstance arg1 = return Success
| otherwise = subUnify info (argExpr arg1, argExpr arg2)

solveSpine ::
(MonadUnify builtin m) =>
(ConstraintContext builtin, UnificationConstraintOrigin builtin) ->
ConstraintInfo builtin ->
Spine builtin ->
Spine builtin ->
m (UnificationResult builtin)
solveSpine info args1 args2
| length args1 /= length args2 = return HardFailure
| otherwise = do
constraints <- sequence $ mapMaybe (solveArg info) (zip args1 args2)
return $ mconcat constraints
| otherwise = mconcat <$> traverse (solveArg info) (zip args1 args2)

solveLam ::
(MonadUnify builtin m) =>
(ConstraintContext builtin, UnificationConstraintOrigin builtin) ->
ConstraintInfo builtin ->
(VBinder builtin, Closure builtin) ->
(VBinder builtin, Closure builtin) ->
m (UnificationResult builtin)
solveLam info@(ctx, origin) (binder1, Closure env1 body1) (binder2, Closure env2 body2) = do
solveLam info@(WithContext constraint ctx, blockingMeta) (binder1, Closure env1 body1) (binder2, Closure env2 body2) = do
-- Unify binder constraints
binderConstraint <- unify info (typeOf binder1, typeOf binder2)
binderConstraint <- subUnify info (typeOf binder1, typeOf binder2)

-- Evaluate the normalised bodies of the lambdas
let lv = contextDBLevel ctx
Expand All @@ -195,47 +220,58 @@ solveLam info@(ctx, origin) (binder1, Closure env1 body1) (binder2, Closure env2
-- Update the context.
-- NOTE: that we have to unnormalise here indicates something is wrong.
let unnormBinder = fmap (unnormalise lv) binder1
let updatedInfo = (updateConstraintBoundCtx ctx (unnormBinder :), origin)
let newCtx = updateConstraintBoundCtx ctx (unnormBinder :)
let updatedInfo = (WithContext constraint newCtx, blockingMeta)

-- Unify the two bodies
bodyConstraint <- unify updatedInfo (nbody1, nbody2)
bodyConstraint <- subUnify updatedInfo (nbody1, nbody2)

-- Return the result
return $ Success [binderConstraint, bodyConstraint]
return $ binderConstraint <> bodyConstraint

solvePi ::
(MonadUnify builtin m) =>
(ConstraintContext builtin, UnificationConstraintOrigin builtin) ->
ConstraintInfo builtin ->
(VBinder builtin, Value builtin) ->
(VBinder builtin, Value builtin) ->
m (UnificationResult builtin)
solvePi info (binder1, body1) (binder2, body2) = do
-- !!TODO!! Block until binders are solved
-- One possible implementation, blocked metas = set of sets where outer is conjunction and inner is disjunction
-- BOB: this effectively blocks until the binders are solved, because we usually just try to eagerly solve problems
binderConstraint <- unify info (typeOf binder1, typeOf binder2)
bodyConstraint <- unify info (body1, body2)
return $ Success [binderConstraint, bodyConstraint]
binderConstraint <- subUnify info (typeOf binder1, typeOf binder2)
bodyConstraint <- subUnify info (body1, body2)
return $ binderConstraint <> bodyConstraint

solveFlexFlex :: (MonadUnify builtin m) => ConstraintContext builtin -> (MetaID, Spine builtin) -> (MetaID, Spine builtin) -> m (UnificationResult builtin)
solveFlexFlex ctx (meta1, spine1) (meta2, spine2) = do
solveFlexFlex ::
(MonadUnify builtin m) =>
ConstraintInfo builtin ->
(MetaID, Spine builtin) ->
(MetaID, Spine builtin) ->
m (UnificationResult builtin)
solveFlexFlex info (meta1, spine1) (meta2, spine2) = do
-- It may be that only one of the two spines is invertible
maybeRenaming <- invert (contextDBLevel ctx) (meta1, spine1)
maybeRenaming <- invert (contextDBLevel (ctxOf info)) (meta1, spine1)
case maybeRenaming of
Nothing -> solveFlexRigid ctx (meta2, spine2) (VMeta meta1 spine1)
Just renaming -> solveFlexRigidWithRenaming ctx (meta1, spine1) renaming (VMeta meta2 spine2)
Nothing -> solveFlexRigid info (meta2, spine2) (VMeta meta1 spine1)
Just renaming -> solveFlexRigidWithRenaming (ctxOf info) (meta1, spine1) renaming (VMeta meta2 spine2)

solveFlexRigid :: (MonadUnify builtin m) => ConstraintContext builtin -> (MetaID, Spine builtin) -> Value builtin -> m (UnificationResult builtin)
solveFlexRigid ctx (metaID, spine) solution = do
solveFlexRigid ::
(MonadUnify builtin m) =>
ConstraintInfo builtin ->
(MetaID, Spine builtin) ->
Value builtin ->
m (UnificationResult builtin)
solveFlexRigid info (metaID, spine) solution = do
-- Check that 'spine' is a pattern and try to calculate a substitution
-- that renames the variables in `solution` to ones available to `meta`
maybeRenaming <- invert (contextDBLevel ctx) (metaID, spine)
maybeRenaming <- invert (contextDBLevel (ctxOf info)) (metaID, spine)
case maybeRenaming of
Just renaming -> solveFlexRigidWithRenaming ctx (metaID, spine) renaming solution
Just renaming -> solveFlexRigidWithRenaming (ctxOf info) (metaID, spine) renaming solution
-- This constraint is stuck because it is not pattern; shelve
-- it for now and hope that another constraint allows us to
-- progress.
Nothing -> return $ SoftFailure $ MetaSet.singleton metaID
Nothing -> block info (Just (MetaSet.singleton metaID))

solveFlexRigidWithRenaming ::
forall builtin m.
Expand All @@ -254,7 +290,7 @@ solveFlexRigidWithRenaming ctx meta@(metaID, _) renaming solution = do
let unnormSolution = quote mempty (contextDBLevel ctx) prunedSolution
let substSolution = substDBAll 0 (\v -> unIx v `IntMap.lookup` renaming) unnormSolution
solveMeta metaID substSolution (boundContext ctx)
return $ Success mempty
return Success

pruneMetaDependencies ::
forall builtin m.
Expand Down

0 comments on commit 1c0c9d5

Please sign in to comment.