diff --git a/src/Lean/Elab/App.lean b/src/Lean/Elab/App.lean index 7752ac82024f..3a9d32af64c9 100644 --- a/src/Lean/Elab/App.lean +++ b/src/Lean/Elab/App.lean @@ -1135,24 +1135,29 @@ private def throwLValError (e : Expr) (eType : Expr) (msg : MessageData) : TermE throwError "{msg}{indentExpr e}\nhas type{indentExpr eType}" /-- -`findMethod? env S fName`. -- If `env` contains `S ++ fName`, return `(S, S++fName)` -- Otherwise if `env` contains private name `prv` for `S ++ fName`, return `(S, prv)`, o -- Otherwise for each parent structure `S'` of `S`, we try `findMethod? env S' fname` +`findMethod? S fName` tries the following for each namespace `S'` in the resolution order for `S`: +- If `env` contains `S' ++ fName`, returns `(S', S' ++ fName)` +- Otherwise if `env` contains private name `prv` for `S' ++ fName`, returns `(S', prv)` -/ -private partial def findMethod? (env : Environment) (structName fieldName : Name) : Option (Name × Name) := - let fullName := structName ++ fieldName - match env.find? fullName with - | some _ => some (structName, fullName) - | none => +private partial def findMethod? (structName fieldName : Name) : MetaM (Option (Name × Name)) := do + let env ← getEnv + let find? structName' : MetaM (Option (Name × Name)) := do + let fullName := structName' ++ fieldName + if env.contains fullName then + return some (structName', fullName) let fullNamePrv := mkPrivateName env fullName - match env.find? fullNamePrv with - | some _ => some (structName, fullNamePrv) - | none => - if isStructure env structName then - (getStructureSubobjects env structName).findSome? fun parentStructName => findMethod? env parentStructName fieldName - else - none + if env.contains fullNamePrv then + return some (structName', fullNamePrv) + return none + -- Optimization: the first element of the resolution order is `structName`, + -- so we can skip computing the resolution order in the common case + -- of the name resolving in the `structName` namespace. + find? 
structName <||> do + let resolutionOrder ← if isStructure env structName then getStructureResolutionOrder structName else pure #[structName] + for h : i in [1:resolutionOrder.size] do + if let some res ← find? resolutionOrder[i] then + return res + return none /-- Return `some (structName', fullName)` if `structName ++ fieldName` is an alias for `fullName`, and @@ -1204,7 +1209,7 @@ private def resolveLValAux (e : Expr) (eType : Expr) (lval : LVal) : TermElabM L | some structName, LVal.fieldName _ fieldName _ _ => let env ← getEnv let searchEnv : Unit → TermElabM LValResolution := fun _ => do - if let some (baseStructName, fullName) := findMethod? env structName (.mkSimple fieldName) then + if let some (baseStructName, fullName) ← findMethod? structName (.mkSimple fieldName) then return LValResolution.const baseStructName structName fullName else if let some (structName', fullName) := findMethodAlias? env structName (.mkSimple fieldName) then return LValResolution.const structName' structName' fullName @@ -1390,19 +1395,17 @@ private def elabAppLValsAux (namedArgs : Array NamedArg) (args : Array Arg) (exp loop f lvals | LValResolution.projFn baseStructName structName fieldName => let f ← mkBaseProjections baseStructName structName f - if let some info := getFieldInfo? (← getEnv) baseStructName fieldName then - if isPrivateNameFromImportedModule (← getEnv) info.projFn then - throwError "field '{fieldName}' from structure '{structName}' is private" - let projFn ← mkConst info.projFn - let projFn ← addProjTermInfo lval.getRef projFn - if lvals.isEmpty then - let namedArgs ← addNamedArg namedArgs { name := `self, val := Arg.expr f, suppressDeps := true } - elabAppArgs projFn namedArgs args expectedType? explicit ellipsis - else - let f ← elabAppArgs projFn #[{ name := `self, val := Arg.expr f, suppressDeps := true }] #[] (expectedType? := none) (explicit := false) (ellipsis := false) - loop f lvals + let some info := getFieldInfo? 
(← getEnv) baseStructName fieldName | unreachable! + if isPrivateNameFromImportedModule (← getEnv) info.projFn then + throwError "field '{fieldName}' from structure '{structName}' is private" + let projFn ← mkConst info.projFn + let projFn ← addProjTermInfo lval.getRef projFn + if lvals.isEmpty then + let namedArgs ← addNamedArg namedArgs { name := `self, val := Arg.expr f, suppressDeps := true } + elabAppArgs projFn namedArgs args expectedType? explicit ellipsis else - unreachable! + let f ← elabAppArgs projFn #[{ name := `self, val := Arg.expr f, suppressDeps := true }] #[] (expectedType? := none) (explicit := false) (ellipsis := false) + loop f lvals | LValResolution.const baseStructName structName constName => let f ← if baseStructName != structName then mkBaseProjections baseStructName structName f else pure f let projFn ← mkConst constName diff --git a/src/Lean/Elab/Structure.lean b/src/Lean/Elab/Structure.lean index 4a8db09ad52d..be5c1f8f0970 100644 --- a/src/Lean/Elab/Structure.lean +++ b/src/Lean/Elab/Structure.lean @@ -22,7 +22,12 @@ namespace Lean.Elab.Command register_builtin_option structureDiamondWarning : Bool := { defValue := false - descr := "enable/disable warning messages for structure diamonds" + descr := "if true, enable warnings when a structure has diamond inheritance" +} + +register_builtin_option structure.strictResolutionOrder : Bool := { + defValue := false + descr := "if true, require a strict resolution order for structures" } open Meta @@ -943,6 +948,23 @@ private def mkInductiveType (view : StructView) (indFVar : Expr) (levelNames : L instantiateMVars (← mkForallFVars params type) return { name := view.declName, type := ← instantiateMVars type, ctors := [{ ctor with type := ← instantiateMVars ctorType }] } +/-- +Precomputes the structure's resolution order. +Option `structure.strictResolutionOrder` controls whether to create a warning if the C3 algorithm failed. 
+-/ +private def checkResolutionOrder (structName : Name) : TermElabM Unit := do + let resolutionOrderResult ← computeStructureResolutionOrder structName (relaxed := !structure.strictResolutionOrder.get (← getOptions)) + trace[Elab.structure.resolutionOrder] "computed resolution order: {resolutionOrderResult.resolutionOrder}" + unless resolutionOrderResult.conflicts.isEmpty do + let mut defects : List MessageData := [] + for conflict in resolutionOrderResult.conflicts do + let parentKind direct := if direct then "parent" else "indirect parent" + let conflicts := conflict.conflicts.map fun (isDirect, name) => + m!"{parentKind isDirect} '{MessageData.ofConstName name}'" + defects := m!"- {parentKind conflict.isDirectParent} '{MessageData.ofConstName conflict.badParent}' \ + must come after {MessageData.andList conflicts.toList}" :: defects + logWarning m!"failed to compute strict resolution order:\n{MessageData.joinSep defects.reverse "\n"}" + def mkStructureDecl (vars : Array Expr) (view : StructView) : TermElabM Unit := Term.withoutSavingRecAppSyntax do let scopeLevelNames ← Term.getLevelNames let isUnsafe := view.modifiers.isUnsafe @@ -1008,6 +1030,8 @@ def mkStructureDecl (vars : Array Expr) (view : StructView) : TermElabM Unit := else mkCoercionToCopiedParent levelParams params view parent.structName parent.type setStructureParents view.declName parentInfos + checkResolutionOrder view.declName + let lctx ← getLCtx /- The `lctx` and `defaultAuxDecls` are used to create the auxiliary "default value" declarations The parameters `params` for these definitions must be marked as implicit, and all others as explicit. 
-/ @@ -1045,6 +1069,8 @@ def elabStructure (modifiers : Modifiers) (stx : Syntax) : CommandElabM Unit := pure view elabStructureViewPostprocessing view -builtin_initialize registerTraceClass `Elab.structure +builtin_initialize + registerTraceClass `Elab.structure + registerTraceClass `Elab.structure.resolutionOrder end Lean.Elab.Command diff --git a/src/Lean/Server/Completion.lean b/src/Lean/Server/Completion.lean index 83a38dac54aa..75623140c28c 100644 --- a/src/Lean/Server/Completion.lean +++ b/src/Lean/Server/Completion.lean @@ -535,7 +535,7 @@ where let .const typeName _ := type.getAppFn | return () modify fun s => s.insert typeName if isStructure (← getEnv) typeName then - for parentName in getAllParentStructures (← getEnv) typeName do + for parentName in (← getAllParentStructures typeName) do modify fun s => s.insert parentName let some type ← unfoldeDefinitionGuarded? type | return () visit type diff --git a/src/Lean/Structure.lean b/src/Lean/Structure.lean index 089d7f7c78fd..817da6c82be9 100644 --- a/src/Lean/Structure.lean +++ b/src/Lean/Structure.lean @@ -154,6 +154,7 @@ def getStructureCtor (env : Environment) (constName : Name) : ConstructorVal := def getStructureFields (env : Environment) (structName : Name) : Array Name := (getStructureInfo env structName).fieldNames +/-- Get the `StructureFieldInfo` for the given direct field of the structure. -/ def getFieldInfo? (env : Environment) (structName : Name) (fieldName : Name) : Option StructureFieldInfo := if let some info := getStructureInfo? env structName then info.fieldInfo.binSearch { fieldName := fieldName, projFn := default, subobject? := none, binderInfo := default } StructureFieldInfo.lt @@ -180,21 +181,7 @@ If a direct parent cannot itself be represented as a subobject, sometimes one of its parents (or one of their parents, etc.) can. 
-/ def getStructureSubobjects (env : Environment) (structName : Name) : Array Name := - let fieldNames := getStructureFields env structName; - fieldNames.foldl (init := #[]) fun acc fieldName => - match isSubobjectField? env structName fieldName with - | some parentStructName => acc.push parentStructName - | none => acc - --- TODO: use actual parents, not just subobjects. -/-- Return all parent structures -/ -partial def getAllParentStructures (env : Environment) (structName : Name) : Array Name := - visit structName |>.run #[] |>.2 -where - visit (structName : Name) : StateT (Array Name) Id Unit := do - for p in getStructureSubobjects env structName do - modify fun s => s.push p - visit p + (getStructureFields env structName).filterMap (isSubobjectField? env structName) /-- Return the name of the structure that contains the field relative to structure `structName`. @@ -269,18 +256,23 @@ partial def getPathToBaseStructureAux (env : Environment) (baseStructName : Name if baseStructName == structName then some path.reverse else - let fieldNames := getStructureFields env structName; - fieldNames.findSome? fun fieldName => - match isSubobjectField? env structName fieldName with - | none => none - | some parentStructName => - match getProjFnForField? env structName fieldName with - | none => none - | some projFn => getPathToBaseStructureAux env baseStructName parentStructName (projFn :: path) + if let some info := getStructureInfo? env structName then + -- Prefer subobject projections + (info.fieldInfo.findSome? fun field => + match field.subobject? with + | none => none + | some parentStructName => getPathToBaseStructureAux env baseStructName parentStructName (field.projFn :: path)) + -- Otherwise, consider other parents + <|> info.parentInfo.findSome? 
fun parent => + if parent.subobject then + none + else + getPathToBaseStructureAux env baseStructName parent.structName (parent.projFn :: path) + else none /-- -If `baseStructName` is an ancestor structure for `structName`, then returns a sequence of projection functions -to go from `structName` to `baseStructName`. +If `baseStructName` is an ancestor structure for `structName`, then return a sequence of projection functions +to go from `structName` to `baseStructName`. Returns `[]` if `baseStructName == structName`. -/ def getPathToBaseStructure? (env : Environment) (baseStructName : Name) (structName : Name) : Option (List Name) := getPathToBaseStructureAux env baseStructName structName [] @@ -315,4 +307,132 @@ def getStructureLikeNumFields (env : Environment) (constName : Name) : Nat := | _ => 0 | _ => 0 +/-! +### Resolution orders + +This section is for computations to determine which namespaces to visit when resolving field notation. +While the set of namespaces is clear (after a structure's namespace, it is the namespaces for *all* parents), +the question is the order to visit them in. + +We use the C3 superclass linearization algorithm from Barrett et al., "A Monotonic Superclass Linearization for Dylan", OOPSLA 1996. +For reference, the C3 linearization is known as the "method resolution order" (MRO) [in Python](https://docs.python.org/3/howto/mro.html). + +The basic idea is that we want to find a resolution order with the following property: +For each structure `S` that appears in the resolution order, if its direct parents are `P₁ .. Pₙ`, +then `S P₁ ... Pₙ` forms a subsequence of the resolution order. + +This has a stability property where if `S` extends `S'`, then the resolution order of `S` contains the resolution order of `S'` as a subsequence. +It also has the key property that if `P` and `P'` are parents of `S`, then we visit `P` and `P'` before we visit the shared parents of `P` and `P'`. + +Finding such a resolution order might not be possible. 
+Still, we can enable a relaxation of the algorithm by ignoring one or more parent resolution orders, starting from the end. + +In Hivert and Thiéry "Controlling the C3 super class linearization algorithm for large hierarchies of classes" +https://arxiv.org/pdf/2401.12740 the authors discuss how in SageMath, which has thousands of classes, +C3 can be difficult to control, since maintaining correct direct parent orders is a burden. +They give suggestions that have worked for the SageMath project. +We may consider introducing an environment extension with ordering hints to help guide the algorithm if we see similar difficulties. +-/ + +structure StructureResolutionState where + resolutions : PHashMap Name (Array Name) := {} + deriving Inhabited + +/-- +We use an environment extension to cache resolution orders. +These are not expensive to compute, but worth caching, and we save olean storage space. +-/ +builtin_initialize structureResolutionExt : EnvExtension StructureResolutionState ← + registerEnvExtension (pure {}) + +/-- Gets the resolution order if it has already been cached. -/ +private def getStructureResolutionOrder? (env : Environment) (structName : Name) : Option (Array Name) := + (structureResolutionExt.getState env).resolutions.find? structName + +/-- Caches a structure's resolution order. -/ +private def setStructureResolutionOrder [MonadEnv m] (structName : Name) (resolutionOrder : Array Name) : m Unit := + modifyEnv fun env => structureResolutionExt.modifyState env fun s => + { s with resolutions := s.resolutions.insert structName resolutionOrder } + +/-- The `badParent` must come after the `conflicts`. -/ +structure StructureResolutionOrderConflict where + isDirectParent : Bool + badParent : Name + /-- Conflicts that must come before `badParent`. The flag is whether it is a direct parent. 
-/ + conflicts : Array (Bool × Name) + deriving Inhabited + +structure StructureResolutionOrderResult where + resolutionOrder : Array Name + conflicts : Array StructureResolutionOrderConflict := #[] + deriving Inhabited + +/-- +Computes and caches the C3 linearization. Assumes parents have already been set with `setStructureParents`. +If `relaxed` is false, then if the linearization cannot be computed, conflicts are recorded in the return value. +-/ +partial def computeStructureResolutionOrder [Monad m] [MonadEnv m] + (structName : Name) (relaxed : Bool) : m StructureResolutionOrderResult := do + let env ← getEnv + if let some resOrder := getStructureResolutionOrder? env structName then + return { resolutionOrder := resOrder } + let parentNames := getStructureParentInfo env structName |>.map (·.structName) + -- Don't be strict about parents: if they were supposed to be checked, they were already checked. + let parentResOrders ← parentNames.mapM fun parentName => return (← computeStructureResolutionOrder parentName true).resolutionOrder + + -- `resOrders` contains the resolution orders to merge. + -- The parent list is inserted as a pseudo resolution order to ensure immediate parents come out in order, + -- and it is added first to be the primary ordering constraint when there are ordering errors. + let mut resOrders := parentResOrders.insertAt 0 parentNames |>.filter (!·.isEmpty) + + let mut resOrder : Array Name := #[structName] + let mut defects : Array StructureResolutionOrderConflict := #[] + -- Every iteration of the loop, the sum of the sizes of the arrays in `resOrders` decreases by at least one, + -- so it terminates. + while !resOrders.isEmpty do + let (good, name) ← selectParent resOrders + + unless good || relaxed do + let conflicts := resOrders |>.filter (·[1:].any (· == name)) |>.map (·[0]!) 
|>.qsort Name.lt |>.eraseReps + defects := defects.push { + isDirectParent := parentNames.contains name + badParent := name + conflicts := conflicts.map fun c => (parentNames.contains c, c) + } + + resOrder := resOrder.push name + resOrders := resOrders + |>.map (fun resOrder => resOrder.filter (· != name)) + |>.filter (!·.isEmpty) + + setStructureResolutionOrder structName resOrder + return { resolutionOrder := resOrder, conflicts := defects } +where + selectParent (resOrders : Array (Array Name)) : m (Bool × Name) := do + -- Assumption: every resOrder is nonempty. + -- `n'` is for relaxation, to stop paying attention to end of `resOrders` when finding a good parent. + for n' in [0 : resOrders.size] do + let hi := resOrders.size - n' + for i in [0 : hi] do + let parent := resOrders[i]![0]! + let consistent resOrder := resOrder[1:].all (· != parent) + if resOrders[0:i].all consistent && resOrders[i+1:hi].all consistent then + return (n' == 0, parent) + -- unreachable, but correct default: + return (false, resOrders[0]![0]!) + +/-- +Gets the resolution order for a structure. +-/ +def getStructureResolutionOrder [Monad m] [MonadEnv m] + (structName : Name) : m (Array Name) := + (·.resolutionOrder) <$> computeStructureResolutionOrder structName (relaxed := true) + +/-- +Returns the transitive closure of all parent structures of the structure. +This is the same as `Lean.getStructureResolutionOrder` but without including `structName`. +-/ +partial def getAllParentStructures [Monad m] [MonadEnv m] (structName : Name) : m (Array Name) := + (·.erase structName) <$> getStructureResolutionOrder structName + end Lean diff --git a/tests/lean/run/3467.lean b/tests/lean/run/3467.lean new file mode 100644 index 000000000000..d3df62417278 --- /dev/null +++ b/tests/lean/run/3467.lean @@ -0,0 +1,139 @@ +import Lean +/-! +# Tests for structure resolution order. + +https://github.com/leanprover/lean4/issues/3467 +https://github.com/leanprover/lean4/issues/1881 +-/ + +/-! 
+Basic diamond +-/ + +set_option structure.strictResolutionOrder true +set_option trace.Elab.structure.resolutionOrder true + +/-- info: [Elab.structure.resolutionOrder] computed resolution order: [A] -/ +#guard_msgs in structure A +/-- info: [Elab.structure.resolutionOrder] computed resolution order: [B, A] -/ +#guard_msgs in structure B extends A +/-- info: [Elab.structure.resolutionOrder] computed resolution order: [C, A] -/ +#guard_msgs in structure C extends A +/-- info: [Elab.structure.resolutionOrder] computed resolution order: [D, B, C, A] -/ +#guard_msgs in structure D extends B, C + +def A.x (a : A) : Bool := default +def B.x (b : B) : Nat := default +def A.y (c : A) : Bool := default +def C.y (d : C) : Nat := default + +variable (a : A) (b : B) (c : C) (d : D) + +/-- info: a.x : Bool -/ +#guard_msgs in #check a.x +/-- info: b.x : Nat -/ +#guard_msgs in #check b.x +/-- info: c.x : Bool -/ +#guard_msgs in #check c.x +/-- info: d.x : Nat -/ +#guard_msgs in #check d.x +/-- info: a.y : Bool -/ +#guard_msgs in #check a.y +/-- info: b.y : Bool -/ +#guard_msgs in #check b.y +/-- info: c.y : Nat -/ +#guard_msgs in #check c.y +/-- info: d.toC.y : Nat -/ +#guard_msgs in #check d.y + + +/-! +Example resolution order failure +-/ + +/-- +warning: failed to compute strict resolution order: +- parent 'B' must come after parent 'D' +--- +info: [Elab.structure.resolutionOrder] computed resolution order: [D', B, D, C, A] +-/ +#guard_msgs in +structure D' extends B, D + + +/-! +Example from issue 3467. 
+-/ + +namespace Issue3467 + +/-- info: [Elab.structure.resolutionOrder] computed resolution order: [Issue3467.X] -/ +#guard_msgs in +structure X where + base : Nat + +/-- info: [Elab.structure.resolutionOrder] computed resolution order: [Issue3467.A, Issue3467.X] -/ +#guard_msgs in +structure A extends X where + countA : Nat + +/-- info: [Elab.structure.resolutionOrder] computed resolution order: [Issue3467.B, Issue3467.X] -/ +#guard_msgs in +structure B extends X where + countB : Nat + +namespace A + +def getTwiceCountA (a : A) := a.countA * 2 + +end A + +namespace B + +def getTwiceCountB (b : B) := b.countB * 2 + +end B + +/-- +info: [Elab.structure.resolutionOrder] computed resolution order: [Issue3467.C, Issue3467.A, Issue3467.B, Issue3467.X] +-/ +#guard_msgs in +structure C extends A, B + +def getCounts (c : C) := + c.countA + c.countB + +def getTwiceCounts (c : C) := + c.getTwiceCountA + c.getTwiceCountB +-- ^^^^ used to fail to resolve + +end Issue3467 + + +namespace Issue1881 + +/-- info: [Elab.structure.resolutionOrder] computed resolution order: [Issue1881.Foo1] -/ +#guard_msgs in +structure Foo1 where + a : Nat + b : Nat + +/-- info: [Elab.structure.resolutionOrder] computed resolution order: [Issue1881.Foo2] -/ +#guard_msgs in +structure Foo2 where + a : Nat + c : Nat + +/-- +info: [Elab.structure.resolutionOrder] computed resolution order: [Issue1881.Foo3, Issue1881.Foo1, Issue1881.Foo2] +-/ +#guard_msgs in +structure Foo3 extends Foo1, Foo2 where + d : Nat + +def Foo1.bar1 (_ : Foo1) : Nat := 0 +def Foo2.bar2 (_ : Foo2) : Nat := 1 +example (x : Foo3) := x.bar1 -- works +example (x : Foo3) := x.bar2 -- now works + +end Issue1881 diff --git a/tests/lean/run/structure.lean b/tests/lean/run/structure.lean index 4731339e6f89..fc2432600519 100644 --- a/tests/lean/run/structure.lean +++ b/tests/lean/run/structure.lean @@ -28,7 +28,7 @@ info: #[const2ModIdx, constants, extensions, extraConstNames, header] #[toS2, toS1, x, y, z, toS3, w, s] (some [S4.toS2, 
S2.toS1]) #[S2, S3] -#[S2, S1, S3] +#[S2, S3, S1] -/ #guard_msgs in #eval show CoreM Unit from do @@ -51,7 +51,7 @@ info: #[const2ModIdx, constants, extensions, extraConstNames, header] IO.println (getStructureFieldsFlattened env `S4) IO.println (getPathToBaseStructure? env `S1 `S4) IO.println (getStructureSubobjects env `S4) - IO.println (getAllParentStructures env `S4) + IO.println (← getAllParentStructures `S4) pure () def dumpStructInfo (structName : Name) : CoreM Unit := do