Skip to content

Commit

Permalink
Reverting later improves order of cases? oh wey
Browse files Browse the repository at this point in the history
  • Loading branch information
nomeata committed Feb 29, 2024
1 parent 1380721 commit d9f9e2c
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 46 deletions.
38 changes: 22 additions & 16 deletions src/Lean/Meta/Tactic/FunInd.lean
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,7 @@ def assertIHs (vals : Array Expr) (mvarid : MVarId) : MetaM MVarId := do
return mvarid

/-- Base case of `buildInductionBody`: Construct a case for the final induction hypthesis. -/
def buildInductionCase (motiveFVar : FVarId) (fn : Expr) (oldIH newIH : FVarId) (toClear toPreserve : Array FVarId)
def buildInductionCase (fn : Expr) (oldIH newIH : FVarId) (toClear toPreserve : Array FVarId)
(goal : Expr) (IHs : Array Expr) (e : Expr) : MetaM Expr := do
let IHs := IHs ++ (← collectIHs fn oldIH newIH e)
let IHs ← deduplicateIHs IHs
Expand All @@ -404,7 +404,6 @@ def buildInductionCase (motiveFVar : FVarId) (fn : Expr) (oldIH newIH : FVarId)
mvarId ← mvarId.clear fvarId
mvarId ← mvarId.cleanup (toPreserve := toPreserve)
mvarId ← substVars mvarId
let (_, _mvarId) ← mvarId.revertAfter motiveFVar
let mvar ← instantiateMVars mvar
pure mvar

Expand Down Expand Up @@ -443,7 +442,7 @@ def maskArray {α} (mask : Array Bool) (xs : Array α) : Array α := Id.run do
if b then ys := ys.push x
return ys

partial def buildInductionBody (motiveFVar : FVarId) (fn : Expr) (toClear toPreserve : Array FVarId)
partial def buildInductionBody (fn : Expr) (toClear toPreserve : Array FVarId)
(goal : Expr) (oldIH newIH : FVarId) (IHs : Array Expr) (e : Expr) : MetaM Expr := do

if e.isDIte then
Expand All @@ -453,11 +452,11 @@ partial def buildInductionBody (motiveFVar : FVarId) (fn : Expr) (toClear toPres
let h' ← foldCalls fn oldIH h
let t' ← withLocalDecl `h .default c' fun h => do
let t ← instantiateLambda t #[h]
let t' ← buildInductionBody motiveFVar fn toClear (toPreserve.push h.fvarId!) goal oldIH newIH IHs t
let t' ← buildInductionBody fn toClear (toPreserve.push h.fvarId!) goal oldIH newIH IHs t
mkLambdaFVars #[h] t'
let f' ← withLocalDecl `h .default (mkNot c') fun h => do
let f ← instantiateLambda f #[h]
let f' ← buildInductionBody motiveFVar fn toClear (toPreserve.push h.fvarId!) goal oldIH newIH IHs f
let f' ← buildInductionBody fn toClear (toPreserve.push h.fvarId!) goal oldIH newIH IHs f
mkLambdaFVars #[h] f'
let u ← getLevel goal
return mkApp5 (mkConst ``dite [u]) goal c' h' t' f'
Expand All @@ -483,7 +482,7 @@ partial def buildInductionBody (motiveFVar : FVarId) (fn : Expr) (toClear toPres
removeLamda alt fun oldIH' alt => do
forallBoundedTelescope expAltType (some 1) fun newIH' goal' => do
let #[newIH'] := newIH' | unreachable!
let alt' ← buildInductionBody motiveFVar fn (toClear.push newIH'.fvarId!) toPreserve goal' oldIH' newIH'.fvarId! IHs alt
let alt' ← buildInductionBody fn (toClear.push newIH'.fvarId!) toPreserve goal' oldIH' newIH'.fvarId! IHs alt
mkLambdaFVars #[newIH'] alt')
(onRemaining := fun _ => pure #[.fvar newIH])
return matcherApp'.toExpr
Expand All @@ -499,26 +498,26 @@ partial def buildInductionBody (motiveFVar : FVarId) (fn : Expr) (toClear toPres
(onParams := foldCalls fn oldIH)
(onMotive := fun xs _body => pure (absMotiveBody.beta (maskArray mask xs)))
(onAlt := fun expAltType alt => do
buildInductionBody motiveFVar fn toClear toPreserve expAltType oldIH newIH IHs alt)
buildInductionBody fn toClear toPreserve expAltType oldIH newIH IHs alt)
return matcherApp'.toExpr

if let .letE n t v b _ := e then
let IHs := IHs ++ (← collectIHs fn oldIH newIH v)
let t' ← foldCalls fn oldIH t
let v' ← foldCalls fn oldIH v
return ← withLetDecl n t' v' fun x => do
let b' ← buildInductionBody motiveFVar fn toClear toPreserve goal oldIH newIH IHs (b.instantiate1 x)
let b' ← buildInductionBody fn toClear toPreserve goal oldIH newIH IHs (b.instantiate1 x)
mkLetFVars #[x] b'

if let some (n, t, v, b) := e.letFun? then
let IHs := IHs ++ (← collectIHs fn oldIH newIH v)
let t' ← foldCalls fn oldIH t
let v' ← foldCalls fn oldIH v
return ← withLocalDecl n .default t' fun x => do
let b' ← buildInductionBody motiveFVar fn toClear toPreserve goal oldIH newIH IHs (b.instantiate1 x)
let b' ← buildInductionBody fn toClear toPreserve goal oldIH newIH IHs (b.instantiate1 x)
mkLetFun x v' b'

buildInductionCase motiveFVar fn oldIH newIH toClear toPreserve goal IHs e
buildInductionCase fn oldIH newIH toClear toPreserve goal IHs e

partial def findFixF {α} (name : Name) (e : Expr) (k : Array Expr → Expr → MetaM α) : MetaM α := do
lambdaTelescope e fun params body => do
Expand Down Expand Up @@ -552,8 +551,6 @@ def deriveUnaryInduction (name : Name) : MetaM Name := do
unless ← isDefEq arg params.back do
throwError "fixF application argument {arg} is not function argument "
let [argLevel, _motiveLevel] := f.constLevels! | unreachable!
-- logInfo body
-- mkFresh

let motiveType ← mkArrow argType (.sort levelZero)
withLocalDecl `motive .default motiveType fun motive => do
Expand All @@ -565,7 +562,7 @@ def deriveUnaryInduction (name : Name) : MetaM Name := do
-- open body with the same arg
let body ← instantiateLambda body #[param]
removeLamda body fun oldIH body => do
let body' ← buildInductionBody motive.fvarId! fn #[genIH.fvarId!] #[] (.app motive param) oldIH genIH.fvarId! #[] body
let body' ← buildInductionBody fn #[genIH.fvarId!] #[] (.app motive param) oldIH genIH.fvarId! #[] body
if body'.containsFVar oldIH then
throwError m!"Did not fully eliminate {mkFVar oldIH} from induction principle body:{indentExpr body}"
mkLambdaFVars #[param, genIH] body'
Expand All @@ -574,9 +571,18 @@ def deriveUnaryInduction (name : Name) : MetaM Name := do

let e' ← mkLambdaFVars #[params.back] e'
let mvars ← getMVarsNoDelayed e'
-- Using mkLambdaFVars on mvars directly does not reliably replace
-- the mvars with the parameter, in the presence of delayed assignemnts
-- But this code seems to work:
let mvars ← mvars.mapM fun mvar => do
let (_, mvar) ← mvar.revertAfter motive.fvarId!
pure mvar
-- Using `mkLambdaFVars` on mvars directly does not reliably replace
-- the mvars with the parameter, in the presence of delayed assignemnts.
-- Also `abstractMVars` does not handle delayed assignments correctly (as of now).
-- So instead we bring suitable fvars into scope and use `assign`; this handles
-- delayed assignemnts correctly.
-- NB: This idiom only works because
-- * we know that the `MVars` have the right local context (thanks to `mvarId.revertAfter`)
-- * the MVars are independent (so we don’t need to reorder them)
-- * we do no need the mvars in their unassigned form later
let e' ← Meta.withLocalDecls
(mvars.mapIdx (fun i mvar => (s!"case{i.val+1}", .default, (fun _ => mvar.getType))))
fun xs => do
Expand Down
12 changes: 6 additions & 6 deletions tests/lean/run/funind_expr.lean
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,14 @@ info: Expr.typeCheck.induct (motive : Expr → Prop) (case1 : ∀ (a : Nat), mot
Expr.typeCheck b = Maybe.found Ty.nat h₂ →
Expr.typeCheck a = Maybe.found Ty.nat h₁ → motive a → motive b → motive (Expr.plus a b))
(case4 :
∀ (a b : Expr) (h₁ : HasType a Ty.bool) (h₂ : HasType b Ty.bool),
Expr.typeCheck b = Maybe.found Ty.bool h₂ →
Expr.typeCheck a = Maybe.found Ty.bool h₁ → motive a → motive b → motive (Expr.and a b))
(case5 :
∀ (a b : Expr),
(∀ (h₁ : HasType a Ty.nat) (h₂ : HasType b Ty.nat),
Expr.typeCheck a = Maybe.found Ty.nat h₁ → Expr.typeCheck b = Maybe.found Ty.nat h₂ → False) →
motive a → motive b → motive (Expr.plus a b))
(case5 :
∀ (a b : Expr) (h₁ : HasType a Ty.bool) (h₂ : HasType b Ty.bool),
Expr.typeCheck b = Maybe.found Ty.bool h₂ →
Expr.typeCheck a = Maybe.found Ty.bool h₁ → motive a → motive b → motive (Expr.and a b))
(case6 :
∀ (a b : Expr),
(∀ (h₁ : HasType a Ty.bool) (h₂ : HasType b Ty.bool),
Expand Down Expand Up @@ -94,7 +94,7 @@ theorem Expr.typeCheck_complete {e : Expr} : e.typeCheck = .unknown → ¬ HasTy
theorem Expr.typeCheck_complete' {e : Expr} : e.typeCheck = .unknown → ¬ HasType e ty := by
induction e using Expr.typeCheck.induct
all_goals simp [typeCheck]
case case3 | case4 => simp [*]
case case5 iha ihb | case6 iha ihb =>
case case3 | case5 => simp [*]
case case4 iha ihb | case6 iha ihb =>
intro ht; cases ht
next hnp h₁ h₂ => exact hnp h₁ h₂ (typeCheck_correct h₁ (iha · h₁)) (typeCheck_correct h₂ (ihb · h₂))
16 changes: 8 additions & 8 deletions tests/lean/run/funind_proof.lean
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@ end
derive_functional_induction replaceConst

/--
info: Term.replaceConst.induct (a b : String) (motive1 : Term → Prop) (motive2 : List Term → Prop)
(case1 : ∀ (a_1 : String), (a == a_1) = true → motive1 (const a_1))
(case2 : ∀ (a_1 : String), ¬(a == a_1) = true → motive1 (const a_1))
(case3 : ∀ (a : String) (cs : List Term), motive2 cs → motive1 (app a cs)) (case4 : motive2 [])
info: Term.replaceConst.induct (a b : String) (motive1 : Term → Prop) (motive2 : List Term → Prop) (case1 : motive2 [])
(case2 : ∀ (a_1 : String), (a == a_1) = true → motive1 (const a_1))
(case3 : ∀ (a_1 : String), ¬(a == a_1) = true → motive1 (const a_1))
(case4 : ∀ (a : String) (cs : List Term), motive2 cs → motive1 (app a cs))
(case5 : ∀ (c : Term) (cs : List Term), motive1 c → motive2 cs → motive2 (c :: cs)) (x : Term) : motive1 x
-/
#guard_msgs in
Expand All @@ -40,13 +40,13 @@ theorem numConsts_replaceConst (a b : String) (e : Term) : numConsts (replaceCon
apply replaceConst.induct
(motive1 := fun e => numConsts (replaceConst a b e) = numConsts e)
(motive2 := fun es => numConstsLst (replaceConstLst a b es) = numConstsLst es)
case case1 => intro c h; guard_hyp h :ₛ (a == c) = true; simp [replaceConst, numConsts, *]
case case2 => intro c h; guard_hyp h :ₛ ¬(a == c) = true; simp [replaceConst, numConsts, *]
case case3 =>
case case1 => simp [replaceConstLst, numConstsLst, *]
case case2 => intro c h; guard_hyp h :ₛ (a == c) = true; simp [replaceConst, numConsts, *]
case case3 => intro c h; guard_hyp h :ₛ ¬(a == c) = true; simp [replaceConst, numConsts, *]
case case4 =>
intros f cs ih
guard_hyp ih :ₛnumConstsLst (replaceConstLst a b cs) = numConstsLst cs
simp [replaceConst, numConsts, *]
case case4 => simp [replaceConstLst, numConstsLst, *]
case case5 =>
intro c cs ih₁ ih₂
guard_hyp ih₁ :ₛ numConsts (replaceConst a b c) = numConsts c
Expand Down
32 changes: 16 additions & 16 deletions tests/lean/run/funind_tests.lean
Original file line number Diff line number Diff line change
Expand Up @@ -243,8 +243,8 @@ derive_functional_induction with_arg_refining_match1

/--
info: with_arg_refining_match1.induct (motive : Nat → Nat → Prop) (case1 : ∀ (fst : Nat), motive fst 0)
(case2 : ∀ (fst n : Nat), ¬fst = 0 → motive (fst - 1) n → motive fst (Nat.succ n))
(case3 : ∀ (n : Nat), motive 0 (Nat.succ n)) (x x : Nat) : motive x x
(case2 : ∀ (n : Nat), motive 0 (Nat.succ n))
(case3 : ∀ (fst n : Nat), ¬fst = 0 → motive (fst - 1) n → motive fst (Nat.succ n)) (x x : Nat) : motive x x
-/
#guard_msgs in
#check with_arg_refining_match1.induct
Expand All @@ -257,9 +257,9 @@ termination_by i
derive_functional_induction with_arg_refining_match2

/--
info: with_arg_refining_match2.induct (motive : Nat → Nat → Prop) (case1 : ∀ (fst : Nat), ¬fst = 0 → motive fst 0)
(case2 : ∀ (fst : Nat), ¬fst = 0 → ∀ (n : Nat), motive (fst - 1) n → motive fst (Nat.succ n))
(case3 : ∀ (snd : Nat), motive 0 snd) (x x : Nat) : motive x x
info: with_arg_refining_match2.induct (motive : Nat → Nat → Prop) (case1 : ∀ (snd : Nat), motive 0 snd)
(case2 : ∀ (fst : Nat), ¬fst = 0 → motive fst 0)
(case3 : ∀ (fst : Nat), ¬fst = 0 → ∀ (n : Nat), motive (fst - 1) n → motive fst (Nat.succ n)) (x x : Nat) : motive x x
-/
#guard_msgs in
#check with_arg_refining_match2.induct
Expand Down Expand Up @@ -621,17 +621,17 @@ end
derive_functional_induction even

/--
info: EvenOdd.even.induct (motive1 motive2 : Nat → Prop) (case1 : motive1 0)
(case2 : ∀ (n : Nat), motive2 n → motive1 (Nat.succ n)) (case3 : motive2 0)
(case4 : ∀ (n : Nat), motive1 n → motive2 (Nat.succ n)) (x : Nat) : motive1 x
info: EvenOdd.even.induct (motive1 motive2 : Nat → Prop) (case1 : motive1 0) (case2 : motive2 0)
(case3 : ∀ (n : Nat), motive2 n → motive1 (Nat.succ n)) (case4 : ∀ (n : Nat), motive1 n → motive2 (Nat.succ n))
(x : Nat) : motive1 x
-/
#guard_msgs in
#check even.induct

/--
info: EvenOdd.odd.induct (motive1 motive2 : Nat → Prop) (case1 : motive1 0)
(case2 : ∀ (n : Nat), motive2 n → motive1 (Nat.succ n)) (case3 : motive2 0)
(case4 : ∀ (n : Nat), motive1 n → motive2 (Nat.succ n)) (x : Nat) : motive2 x
info: EvenOdd.odd.induct (motive1 motive2 : Nat → Prop) (case1 : motive1 0) (case2 : motive2 0)
(case3 : ∀ (n : Nat), motive2 n → motive1 (Nat.succ n)) (case4 : ∀ (n : Nat), motive1 n → motive2 (Nat.succ n))
(x : Nat) : motive2 x
-/
#guard_msgs in
#check odd.induct
Expand Down Expand Up @@ -763,7 +763,7 @@ derive_functional_induction even._mutual

/--
info: CommandIdempotence.even._mutual.induct (motive : Nat ⊕' Nat → Prop) (case1 : motive (PSum.inl 0))
(case2 : ∀ (n : Nat), motive (PSum.inr n) → motive (PSum.inl (Nat.succ n))) (case3 : motive (PSum.inr 0))
(case2 : motive (PSum.inr 0)) (case3 : ∀ (n : Nat), motive (PSum.inr n) → motive (PSum.inl (Nat.succ n)))
(case4 : ∀ (n : Nat), motive (PSum.inl n) → motive (PSum.inr (Nat.succ n))) (x : Nat ⊕' Nat) : motive x
-/
#guard_msgs in
Expand All @@ -777,16 +777,16 @@ derive_functional_induction even

/--
info: CommandIdempotence.even._mutual.induct (motive : Nat ⊕' Nat → Prop) (case1 : motive (PSum.inl 0))
(case2 : ∀ (n : Nat), motive (PSum.inr n) → motive (PSum.inl (Nat.succ n))) (case3 : motive (PSum.inr 0))
(case2 : motive (PSum.inr 0)) (case3 : ∀ (n : Nat), motive (PSum.inr n) → motive (PSum.inl (Nat.succ n)))
(case4 : ∀ (n : Nat), motive (PSum.inl n) → motive (PSum.inr (Nat.succ n))) (x : Nat ⊕' Nat) : motive x
-/
#guard_msgs in
#check even._mutual.induct

/--
info: CommandIdempotence.even.induct (motive1 motive2 : Nat → Prop) (case1 : motive1 0)
(case2 : ∀ (n : Nat), motive2 n → motive1 (Nat.succ n)) (case3 : motive2 0)
(case4 : ∀ (n : Nat), motive1 n → motive2 (Nat.succ n)) (x : Nat) : motive1 x
info: CommandIdempotence.even.induct (motive1 motive2 : Nat → Prop) (case1 : motive1 0) (case2 : motive2 0)
(case3 : ∀ (n : Nat), motive2 n → motive1 (Nat.succ n)) (case4 : ∀ (n : Nat), motive1 n → motive2 (Nat.succ n))
(x : Nat) : motive1 x
-/
#guard_msgs in
#check even.induct
Expand Down

0 comments on commit d9f9e2c

Please sign in to comment.