diff --git a/src/Lean/Meta/Basic.lean b/src/Lean/Meta/Basic.lean index efd77359d860..c11eabf36b05 100644 --- a/src/Lean/Meta/Basic.lean +++ b/src/Lean/Meta/Basic.lean @@ -1900,6 +1900,10 @@ abbrev isDefEqGuarded (t s : Expr) : MetaM Bool := def isDefEqNoConstantApprox (t s : Expr) : MetaM Bool := approxDefEq <| isDefEq t s +/-- Shorthand for `isDefEq (mkMVar mvarId) val` -/ +def _root_.Lean.MVarId.checkedAssign (mvarId : MVarId) (val : Expr) : MetaM Bool := + isDefEq (mkMVar mvarId) val + /-- Eta expand the given expression. Example: diff --git a/src/Lean/Meta/ExprDefEq.lean b/src/Lean/Meta/ExprDefEq.lean index f08ffd1a273b..7e77eb078c6a 100644 --- a/src/Lean/Meta/ExprDefEq.lean +++ b/src/Lean/Meta/ExprDefEq.lean @@ -929,6 +929,91 @@ partial def check end CheckAssignmentQuick +/-- +Auxiliary function used at `typeOccursCheckImp`. +Given `type`, it tries to eliminate "dependencies". For example, suppose we are trying to +perform the assignment `?m := f (?n a b)` where +``` +?n : let k := g ?m; A -> h k ?m -> C +``` +If we just perform occurs check `?m` at the type of `?n`, we get a failure, but +we claim these occurrences are ok because the type `?n a b : C`. +In the example above, `typeOccursCheckImp` invokes this function with `n := 2`. +Note that we avoid using `whnf` and `inferType` at `typeOccursCheckImp` to minimize the +performance impact of this extra check. + +See test `typeOccursCheckIssue.lean` for an example where this refinement is needed. +The test is derived from a Mathlib file. +-/ +private partial def skipAtMostNumBinders (type : Expr) (n : Nat) : Expr := + match type, n with + | .forallE _ _ b _, n+1 => skipAtMostNumBinders b n + | .mdata _ b, n => skipAtMostNumBinders b n + | .letE _ _ v b _, n => skipAtMostNumBinders (b.instantiate1 v) n + | type, _ => type + +/-- `typeOccursCheck` implementation using unsafe (i.e., pointer equality) features. -/ +private unsafe def typeOccursCheckImp (mctx : MetavarContext) (mvarId : MVarId) (v : Expr) : Bool := + if v.hasExprMVar then + visit v |>.run' mkPtrSet + else + true +where + alreadyVisited (e : Expr) : StateM (PtrSet Expr) Bool := do + if (← get).contains e then + return true + else + modify fun s => s.insert e + return false + occursCheck (type : Expr) : Bool := + let go : StateM MetavarContext Bool := do + Lean.occursCheck mvarId type + -- Remark: it is ok to discard the the "updated" `MetavarContext` because + -- this function assumes all assigned metavariables have already been + -- instantiated. + go.run' mctx + visitMVar (mvarId' : MVarId) (numArgs : Nat := 0) : Bool := + if let some mvarDecl := mctx.findDecl? mvarId' then + occursCheck (skipAtMostNumBinders mvarDecl.type numArgs) + else + false + visitApp (e : Expr) : StateM (PtrSet Expr) Bool := + e.withApp fun f args => do + unless (← args.allM visit) do + return false + if f.isMVar then + return visitMVar f.mvarId! args.size + else + visit f + visit (e : Expr) : StateM (PtrSet Expr) Bool := do + if !e.hasExprMVar then + return true + else if (← alreadyVisited e) then + return true + else match e with + | .mdata _ b => visit b + | .proj _ _ s => visit s + | .app .. => visitApp e + | .lam _ d b _ => visit d <&&> visit b + | .forallE _ d b _ => visit d <&&> visit b + | .letE _ t v b _ => visit t <&&> visit v <&&> visit b + | .mvar mvarId' => return visitMVar mvarId' + | .bvar .. | .sort .. | .const .. | .fvar .. + | .lit .. => return true + +/-- +Check whether there are invalid occurrences of `mvarId` in the type of other metavariables in `v`. +For example, suppose we have +``` +?m_1 : Nat +?m_2 : Fin ?m_1 +``` +The assignment `?m_1 := (?m_2).1` should not be accepted. +See issue #4405 for additional examples. +-/ +private def typeOccursCheck (mctx : MetavarContext) (mvarId : MVarId) (v : Expr) : Bool := + unsafe typeOccursCheckImp mctx mvarId v + /-- Auxiliary function for handling constraints of the form `?m a₁ ... aₙ =?= v`. It will check whether we can perform the assignment @@ -951,11 +1036,15 @@ def checkAssignment (mvarId : MVarId) (fvars : Array Expr) (v : Expr) : MetaM (O let hasCtxLocals := fvars.any fun fvar => mvarDecl.lctx.containsFVar fvar let ctx ← read let mctx ← getMCtx - if CheckAssignmentQuick.check hasCtxLocals mctx ctx.lctx mvarDecl mvarId fvars v then - pure (some v) + let v ← if CheckAssignmentQuick.check hasCtxLocals mctx ctx.lctx mvarDecl mvarId fvars v then + pure v + else if let some v ← CheckAssignment.checkAssignmentAux mvarId fvars hasCtxLocals (← instantiateMVars v) then + pure v else - let v ← instantiateMVars v - CheckAssignment.checkAssignmentAux mvarId fvars hasCtxLocals v + return none + unless typeOccursCheck (← getMCtx) mvarId v do + return none + return some v private def processAssignmentFOApproxAux (mvar : Expr) (args : Array Expr) (v : Expr) : MetaM Bool := match v with diff --git a/src/Lean/MetavarContext.lean b/src/Lean/MetavarContext.lean index c6c743d8006a..9895afd991cf 100644 --- a/src/Lean/MetavarContext.lean +++ b/src/Lean/MetavarContext.lean @@ -341,6 +341,10 @@ class MonadMCtx (m : Type → Type) where getMCtx : m MetavarContext modifyMCtx : (MetavarContext → MetavarContext) → m Unit +instance : MonadMCtx (StateM MetavarContext) where + getMCtx := get + modifyMCtx := modify + export MonadMCtx (getMCtx modifyMCtx) @[always_inline] diff --git a/tests/lean/run/4405.lean b/tests/lean/run/4405.lean new file mode 100644 index 000000000000..65e4a429de5a --- /dev/null +++ b/tests/lean/run/4405.lean @@ -0,0 +1,54 @@ +import Lean.Elab.Command + +/-- +error: application type mismatch + ⟨Nat.lt_irrefl ↑(?m.58 n), Fin.is_lt (?m.58 n)⟩ +argument + Fin.is_lt (?m.58 n) +has type + ↑(?m.58 n) < ?m.57 n : Prop +but is expected to have type + ↑(?m.58 n) < ↑(?m.58 n) : Prop +-/ +#guard_msgs in +def foo := fun n => (not_and_self_iff _).mp ⟨Nat.lt_irrefl _, Fin.is_lt _⟩ + +/-- +error: type mismatch + Fin.is_lt ?m.185 +has type + ↑?m.185 < ?m.184 : Prop +but is expected to have type + ?a < ?a : Prop +--- +error: unsolved goals +case a +⊢ Nat + +this : ?a < ?a +⊢ True +-/ +#guard_msgs in +def test : True := by + have : ((?a : Nat) < ?a : Prop) := by + refine Fin.is_lt ?_ + done + done + +open Lean Meta +/-- +info: Defeq?: false +--- +info: fun x_0 x_1 => x_1 +-/ +#guard_msgs in +run_meta do + let mvarIdNat ← mkFreshExprMVar (.some (.const ``Nat [])) + let mvarIdFin ← mkFreshExprMVar (.some (.app (.const `Fin []) mvarIdNat)) + -- mvarIdNat.assign (.app (.const ``Fin.val []) mvaridFin)) + let b ← isDefEq mvarIdNat (mkApp2 (.const ``Fin.val []) mvarIdNat mvarIdFin) + logInfo m!"Defeq?: {b}" -- prints true + -- Now mvaridNat occurs in its own type + -- This will stack overflow + let r ← abstractMVars mvarIdFin (levels := false) + logInfo m!"{r.expr}" diff --git a/tests/lean/run/typeOccursCheckIssue.lean b/tests/lean/run/typeOccursCheckIssue.lean new file mode 100644 index 000000000000..7c8aaeb70efe --- /dev/null +++ b/tests/lean/run/typeOccursCheckIssue.lean @@ -0,0 +1,76 @@ +namespace SlimCheck + +inductive TestResult (p : Prop) where + | success : PSum Unit p → TestResult p + | gaveUp : Nat → TestResult p + | failure : ¬ p → List String → Nat → TestResult p + deriving Inhabited + +/-- Configuration for testing a property. -/ +structure Configuration where + numInst : Nat := 100 + maxSize : Nat := 100 + numRetries : Nat := 10 + traceDiscarded : Bool := false + traceSuccesses : Bool := false + traceShrink : Bool := false + traceShrinkCandidates : Bool := false + randomSeed : Option Nat := none + quiet : Bool := false + deriving Inhabited + +abbrev Rand := Id + +abbrev Gen (α : Type u) := ReaderT (ULift Nat) Rand α + +/-- `Testable p` uses random examples to try to disprove `p`. -/ +class Testable (p : Prop) where + run (cfg : Configuration) (minimize : Bool) : Gen (TestResult p) + +def NamedBinder (_n : String) (p : Prop) : Prop := p + +namespace TestResult + +def isFailure : TestResult p → Bool + | failure _ _ _ => true + | _ => false + +end TestResult + +namespace Testable + +open TestResult + +def runProp (p : Prop) [Testable p] : Configuration → Bool → Gen (TestResult p) := Testable.run + +variable {var : String} + +def addShrinks (n : Nat) : TestResult p → TestResult p + | TestResult.failure p xs m => TestResult.failure p xs (m + n) + | p => p + +instance [Pure m] : Inhabited (OptionT m α) := ⟨(pure none : m (Option α))⟩ + +class Shrinkable (α : Type u) where + shrink : (x : α) → List α := fun _ ↦ [] + +class SampleableExt (α : Sort u) where + proxy : Type v + [proxyRepr : Repr proxy] + [shrink : Shrinkable proxy] + sample : Gen proxy + interp : proxy → α + +partial def minimizeAux [SampleableExt α] {β : α → Prop} [∀ x, Testable (β x)] (cfg : Configuration) + (var : String) (x : SampleableExt.proxy α) (n : Nat) : + OptionT Gen (Σ x, TestResult (β (SampleableExt.interp x))) := do + let candidates := SampleableExt.shrink.shrink x + for candidate in candidates do + let res ← OptionT.lift <| Testable.runProp (β (SampleableExt.interp candidate)) cfg true + if res.isFailure then + if cfg.traceShrink then + pure () -- slimTrace s!"{var} shrunk to {repr candidate} from {repr x}" + let currentStep := OptionT.lift <| pure <| Sigma.mk candidate (addShrinks (n + 1) res) + let nextStep := minimizeAux cfg var candidate (n + 1) + return ← (nextStep <|> currentStep) + failure