Skip to content

Commit

Permalink
feat: parity between structure instance notation and where notation (
Browse files Browse the repository at this point in the history
…leanprover#6165)

This PR modifies structure instance notation and `where` notation to use
the same notation for fields. Structure instance notation now admits
binders, type ascriptions, and equations, and `where` notation admits
full structure lvals. Examples of these for structure instance notation:
```lean
structure PosFun where
  f : Nat → Nat
  pos : ∀ n, 0 < f n

def p : PosFun :=
  { f n := n + 1
    pos := by simp }

def p' : PosFun :=
  { f | 0 => 1
      | n + 1 => n + 1
    pos := by rintro (_|_) <;> simp }
```
Just like for the structure `where` notation, a field `f x y z : ty :=
val` expands to `f := fun x y z => (val : ty)`. The type ascription is
optional.

The PR also is setting things up for future expansion. Pending some
discussion, in the future structure/`where` notation could have have
embedded `where` clauses; rather than `{ a := { x := 1, y := z } }` one
could write `{ a where x := 1; y := z }`.
  • Loading branch information
kmill authored and JovanGerb committed Jan 21, 2025
1 parent d5b9b2b commit c398f63
Show file tree
Hide file tree
Showing 13 changed files with 328 additions and 174 deletions.
3 changes: 1 addition & 2 deletions src/Init/Notation.lean
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,7 @@ For example, `let x ← e` is a `doElem`, and a `do` block consists of a list of
def doElem : Category := {}

/-- `structInstFieldDecl` is the syntax category for value declarations for fields in structure instance notation.
For example, the `:= 1` and `where a := 3` in `{ x := 1, y where a := 3 }` are in the `structInstFieldDecl` class.
This category is necessary because structure instance notation is recursive due to the `x where ...` field notation. -/
For example, the `:= 1` and `| 0 => 0 | n + 1 => n` in `{ x := 1, f | 0 => 0 | n + 1 => n }` are in the `structInstFieldDecl` class. -/
def structInstFieldDecl : Category := {}

/-- `level` is a builtin syntax category for universe levels.
Expand Down
77 changes: 31 additions & 46 deletions src/Lean/Elab/MutualDef.lean
Original file line number Diff line number Diff line change
Expand Up @@ -282,52 +282,37 @@ private partial def withFunLocalDecls {α} (headers : Array DefViewElabHeader) (
k fvars
loop 0 #[]

private def expandWhereStructInst : Macro
| whereStx@`(Parser.Command.whereStructInst|where%$whereTk $[$decls:letDecl];* $[$whereDecls?:whereDecls]?) => do
let letIdDecls ← decls.mapM fun stx => match stx with
| `(letDecl|$_decl:letPatDecl) => Macro.throwErrorAt stx "patterns are not allowed here"
| `(letDecl|$decl:letEqnsDecl) => expandLetEqnsDecl decl (useExplicit := false)
| `(letDecl|$decl:letIdDecl) => pure decl
| _ => Macro.throwUnsupported
let structInstFields ← letIdDecls.mapM fun
| stx@`(letIdDecl|$id:ident $binders* $[: $ty?]? := $val) => withRef stx do
let mut val := val
if let some ty := ty? then
val ← `(($val : $ty))
-- HACK: this produces invalid syntax, but the fun elaborator supports letIdBinders as well
have : Coe (TSyntax ``letIdBinder) (TSyntax ``funBinder) := ⟨(⟨·⟩)⟩
val ← if binders.size > 0 then `(fun $binders* => $val) else pure val
`(structInstField|$id:ident := $val)
| stx@`(letIdDecl|_ $_* $[: $_]? := $_) => Macro.throwErrorAt stx "'_' is not allowed here"
| _ => Macro.throwUnsupported

let startOfStructureTkInfo : SourceInfo :=
match whereTk.getPos? with
| some pos => .synthetic pos ⟨pos.byteIdx + 1true
| none => .none
-- Position the closing `}` at the end of the trailing whitespace of `where $[$_:letDecl];*`.
-- We need an accurate range of the generated structure instance in the generated `TermInfo`
-- so that we can determine the expected type in structure field completion.
let structureStxTailInfo :=
whereStx[1].getTailInfo?
<|> whereStx[0].getTailInfo?
let endOfStructureTkInfo : SourceInfo :=
match structureStxTailInfo with
| some (SourceInfo.original _ _ trailing _) =>
let tokenPos := trailing.str.prev trailing.stopPos
let tokenEndPos := trailing.stopPos
.synthetic tokenPos tokenEndPos true
| _ => .none

let body ← `(structInst| { $structInstFields,* })
let body := body.raw.setInfo <|
match startOfStructureTkInfo.getPos?, endOfStructureTkInfo.getTailPos? with
| some startPos, some endPos => .synthetic startPos endPos true
| _, _ => .none
match whereDecls? with
| some whereDecls => expandWhereDecls whereDecls body
| none => return body
| _ => Macro.throwUnsupported
private def expandWhereStructInst : Macro := fun whereStx => do
let whereTk := whereStx[0]
let structInstFields : TSyntaxArray ``Parser.Term.structInstField := .mk whereStx[1][0].getSepArgs
let whereDecls? := whereStx[2].getOptional?

let startOfStructureTkInfo : SourceInfo :=
match whereTk.getPos? with
| some pos => .synthetic pos ⟨pos.byteIdx + 1true
| none => .none
-- Position the closing `}` at the end of the trailing whitespace of `where $[$_:letDecl];*`.
-- We need an accurate range of the generated structure instance in the generated `TermInfo`
-- so that we can determine the expected type in structure field completion.
let structureStxTailInfo :=
whereStx[1].getTailInfo?
<|> whereStx[0].getTailInfo?
let endOfStructureTkInfo : SourceInfo :=
match structureStxTailInfo with
| some (SourceInfo.original _ _ trailing _) =>
let tokenPos := trailing.str.prev trailing.stopPos
let tokenEndPos := trailing.stopPos
.synthetic tokenPos tokenEndPos true
| _ => .none

let body ← `(structInst| { $structInstFields,* })
let body := body.raw.setInfo <|
match startOfStructureTkInfo.getPos?, endOfStructureTkInfo.getTailPos? with
| some startPos, some endPos => .synthetic startPos endPos true
| _, _ => .none
match whereDecls? with
| some whereDecls => expandWhereDecls whereDecls body
| none => return body

/-
Recall that
Expand Down
15 changes: 10 additions & 5 deletions src/Lean/Elab/PatternVar.lean
Original file line number Diff line number Diff line change
Expand Up @@ -261,11 +261,16 @@ partial def collect (stx : Syntax) : M Syntax := withRef stx <| withFreshMacroSc
| `({ $[$srcs?,* with]? $fields,* $[..%$ell?]? $[: $ty?]? }) =>
if let some srcs := srcs? then
throwErrorAt (mkNullNode srcs) "invalid struct instance pattern, 'with' is not allowed in patterns"
let fields ← fields.getElems.mapM fun
| `(Parser.Term.structInstField| $lval:structInstLVal := $val) => do
let newVal ← collect val
`(Parser.Term.structInstField| $lval:structInstLVal := $newVal)
| _ => throwInvalidPattern -- `structInstFieldAbbrev` should be expanded at this point
-- TODO(kmill) restore this
-- let fields ← fields.getElems.mapM fun
-- | `(Parser.Term.structInstField| $lval:structInstLVal := $val) => do
-- let newVal ← collect val
-- `(Parser.Term.structInstField| $lval:structInstLVal := $newVal)
-- | _ => throwInvalidPattern -- `structInstFieldAbbrev` should be expanded at this point
let fields ← fields.getElems.mapM fun field => do
let field := field.raw
let val ← collect field[1][2][1]
pure <| field.setArg 1 <| field[1].setArg 2 <| field[1][2].setArg 1 val
`({ $[$srcs?,* with]? $fields,* $[..%$ell?]? $[: $ty?]? })
| _ => throwInvalidPattern

Expand Down
135 changes: 106 additions & 29 deletions src/Lean/Elab/StructInst.lean
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,32 @@ open Meta
open TSyntax.Compat

/-!
Recall that structure instances are of the form:
```
"{" >> optional (atomic (sepBy1 termParser ", " >> " with "))
>> manyIndent (group ((structInstFieldAbbrev <|> structInstField) >> optional ", "))
Recall that structure instances are (after removing parsing and pretty printing hints):
```lean
def structInst := leading_parser
"{ " >> optional (sepBy1 termParser ", " >> " with ")
>> structInstFields (sepByIndent structInstField ", " (allowTrailingSep := true))
>> optEllipsis
>> optional (" : " >> termParser)
>> " }"
>> optional (" : " >> termParser) >> " }"
def structInstField := leading_parser
structInstLVal >> optional (many structInstFieldBinder >> optType >> structInstFieldDecl)
@[builtin_structInstFieldDecl_parser]
def structInstFieldDef := leading_parser
" := " >> termParser
@[builtin_structInstFieldDecl_parser]
def structInstFieldEqns := leading_parser
matchAlts
def structInstWhereBody := leading_parser
structInstFields (sepByIndent structInstField "; " (allowTrailingSep := true))
@[builtin_structInstFieldDecl_parser]
def structInstFieldWhere := leading_parser
"where" >> structInstWhereBody
```
-/

Expand All @@ -54,21 +73,74 @@ Structure instance notation makes use of the expected type.
let stxNew := stx.setArg 4 mkNullNode
`(($stxNew : $expected))

def mkStructInstField (lval : TSyntax ``Parser.Term.structInstLVal) (binders : TSyntaxArray ``Parser.Term.structInstFieldBinder)
(type? : Option Term) (val : Term) : MacroM Term := do
let mut val := val
if let some type := type? then
val ← `(($val : $type))
if !binders.isEmpty then
-- HACK: this produces invalid syntax, but the fun elaborator supports structInstFieldBinder as well
val ← `(fun $binders* => $val)
-- `(Parser.Term.structInstField| $lval := $val)
return mkNode ``Parser.Term.structInstField
#[lval, mkNullNode #[mkNullNode, mkNullNode, mkNode ``Parser.Term.structInstFieldDef #[mkAtom " := ", val]]]

/--
Expands field abbreviation notation.
Example: `{ x, y := 0 }` expands to `{ x := x, y := 0 }`.
Takes an arbitrary `structInstField` and expands it to be a `structInstFieldDef` without any binders or type ascription.
-/
@[builtin_macro Lean.Parser.Term.structInst] def expandStructInstFieldAbbrev : Macro
| `({ $[$srcs,* with]? $fields,* $[..%$ell]? $[: $ty]? }) =>
if fields.getElems.raw.any (·.getKind == ``Lean.Parser.Term.structInstFieldAbbrev) then do
let fieldsNew ← fields.getElems.mapM fun
| `(Parser.Term.structInstFieldAbbrev| $id:ident) =>
`(Parser.Term.structInstField| $id:ident := $id:ident)
| field => return field
`({ $[$srcs,* with]? $fieldsNew,* $[..%$ell]? $[: $ty]? })
private def expandStructInstField (stx : Syntax) : MacroM (Option Syntax) := withRef stx do
if stx.isOfKind `Lean.Parser.Term.structInstField && stx.getNumArgs == 3 then
-- old syntax
let lval : TSyntax ``Parser.Term.structInstLVal := stx[0]
let val : Term := stx[2]
mkStructInstField lval #[] none val
else if stx.isOfKind `Lean.Parser.Term.structInstFieldAbbrev then
-- old syntax
let id : Ident := stx[0]
let lval ← `(Parser.Term.structInstLVal| $id:ident)
mkStructInstField lval #[] none id
else if stx.isOfKind ``Parser.Term.structInstField then
let lval := stx[0]
if stx[1].getNumArgs > 0 then
let binders := stx[1][0].getArgs
let ty? := match stx[1][1] with | `(Parser.Term.optTypeForStructInst| $[: $ty?]?) => ty? | _ => none
let decl := stx[1][2]
match decl with
| `(Parser.Term.structInstFieldDef| := $val) =>
if binders.isEmpty && ty?.isNone then
return none
else
mkStructInstField lval binders ty? val
| `(Parser.Term.structInstFieldEqns| $alts:matchAlts) =>
let val ← expandMatchAltsIntoMatch stx alts (useExplicit := false)
mkStructInstField lval binders ty? val
| _ => Macro.throwUnsupported
else
Macro.throwUnsupported
| _ => Macro.throwUnsupported
-- Abbreviation
match lval with
| `(Parser.Term.structInstLVal| $id:ident) =>
mkStructInstField lval #[] none id
| _ =>
Macro.throwErrorAt lval "unsupported structure instance field abbreviation, expecting identifier"
else
Macro.throwUnsupported

/--
Expands fields.
* Abbrevations. Example: `{ x }` expands to `{ x := x }`.
* Equations. Example: `{ f | 0 => 0 | n + 1 => n }` expands to `{ f := fun x => match x with | 0 => 0 | n + 1 => n }`.
* `where`. Example: `{ s where x := 1 }` expands to `{ s := { x := 1 }}`.
* Binders and types. Example: `{ f n : Nat := n + 1 }` expands to `{ f := fun n => (n + 1 : Nat) }`.
-/
@[builtin_macro Lean.Parser.Term.structInst] def expandStructInstFields : Macro | stx => do
let structInstFields := stx[2]
let fields := structInstFields[0].getSepArgs
let fields? ← fields.mapM expandStructInstField
if fields?.all (·.isNone) then
Macro.throwUnsupported
let fields := fields?.zipWith fields Option.getD
let structInstFields := structInstFields.setArg 0 <| Syntax.mkSep fields (mkAtomFrom stx ", ")
return stx.setArg 2 structInstFields

/--
If `stx` is of the form `{ s₁, ..., sₙ with ... }` and `sᵢ` is not a local variable,
Expand Down Expand Up @@ -187,12 +259,13 @@ def structInstArrayRef := leading_parser "[" >> termParser >>"]"
-/
private def isModifyOp? (stx : Syntax) : TermElabM (Option Syntax) := do
let s? ← stx[2][0].getSepArgs.foldlM (init := none) fun s? arg => do
/- arg is of the form `structInstFieldAbbrev <|> structInstField` -/
if arg.getKind == ``Lean.Parser.Term.structInstField then
/- Remark: the syntax for `structInstField` is
/- arg is of the form `structInstField`. It should be macro expanded at this point, but we make sure it's the case. -/
if arg[1][2].getKind == ``Lean.Parser.Term.structInstFieldDef then
/- Remark: the syntax for `structInstField` after macro expansion is
```
def structInstLVal := leading_parser (ident <|> numLit <|> structInstArrayRef) >> many (group ("." >> (ident <|> numLit)) <|> structInstArrayRef)
def structInstField := leading_parser structInstLVal >> " := " >> termParser
def structInstFieldDef := leading_parser
structInstLVal >> group (null >> null >> group (" := " >> termParser))
```
-/
let lval := arg[0]
Expand Down Expand Up @@ -235,7 +308,7 @@ private def elabModifyOp (stx modifyOp : Syntax) (sources : Array ExplicitSource
withMacroExpansion stx stxNew <| elabTerm stxNew expectedType?
let rest := modifyOp[0][1]
if rest.isNone then
cont modifyOp[2]
cont modifyOp[1][2][1]
else
let s ← `(s)
let valFirst := rest[0]
Expand Down Expand Up @@ -412,7 +485,7 @@ Converts a `Field StructInstView` back into syntax. Used to construct synthetic
private def Field.toSyntax : Field → Syntax
| field =>
let stx := field.ref
let stx := stx.setArg 2 field.val.toSyntax
let stx := stx.setArg 1 <| stx[1].setArg 2 <| stx[1][2].setArg 1 field.val.toSyntax
match field.lhs with
| first::rest => stx.setArg 0 <| mkNullNode #[first.toSyntax true, mkNullNode <| rest.toArray.map (FieldLHS.toSyntax false) ]
| _ => unreachable!
Expand All @@ -428,7 +501,7 @@ private def toFieldLHS (stx : Syntax) : MacroM FieldLHS :=
return FieldLHS.fieldName stx stx.getId.eraseMacroScopes
else match stx.isFieldIdx? with
| some idx => return FieldLHS.fieldIndex stx idx
| none => Macro.throwError "unexpected structure syntax"
| none => Macro.throwErrorAt stx "unexpected structure syntax"

/--
Creates a structure instance view from structure instance notation
Expand All @@ -439,16 +512,20 @@ private def mkStructView (stx : Syntax) (structName : Name) (sources : SourcesVi
/- Recall that `stx` is of the form
```
leading_parser "{" >> optional (atomic (sepBy1 termParser ", " >> " with "))
>> structInstFields (sepByIndent (structInstFieldAbbrev <|> structInstField) ...)
>> structInstFields (sepByIndent structInstField ...)
>> optional ".."
>> optional (" : " >> termParser)
>> " }"
```
This method assumes that `structInstFieldAbbrev` had already been expanded.
This method assumes that `structInstField` had already been expanded by the macro `expandStructInstFields`
and is of the form
```
def structInstFieldDef := leading_parser
structInstLVal >> group (null >> null >> group (" := " >> termParser))
```
-/
let fields ← stx[2][0].getSepArgs.toList.mapM fun fieldStx => do
let val := fieldStx[2]
let val := fieldStx[1][2][1]
let first ← toFieldLHS fieldStx[0][0]
let rest ← fieldStx[0][1].getArgs.toList.mapM toFieldLHS
return { ref := fieldStx, lhs := first :: rest, val := FieldVal.term val : Field }
Expand Down
2 changes: 1 addition & 1 deletion src/Lean/Linter/UnusedVariables.lean
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ builtin_initialize addBuiltinUnusedVariablesIgnoreFn (fun _ stack opts =>
!getLinterUnusedVariablesFunArgs opts &&
stack.matches [`null, none, `null, ``Lean.Parser.Term.letIdDecl, none] &&
(stack.get? 3 |>.any fun (_, pos) => pos == 1) &&
(stack.get? 5 |>.any fun (stx, _) => !stx.isOfKind ``Lean.Parser.Command.whereStructField))
(stack.get? 5 |>.any fun (stx, _) => !stx.isOfKind ``Lean.Parser.Term.structInstField))

/--
Function argument in declaration signature (when `linter.unusedVariables.funArgs` is false)
Expand Down
4 changes: 1 addition & 3 deletions src/Lean/Parser/Command.lean
Original file line number Diff line number Diff line change
Expand Up @@ -134,10 +134,8 @@ def declValSimple := leading_parser
" :=" >> ppHardLineUnlessUngrouped >> declBody >> Termination.suffix >> optional Term.whereDecls
def declValEqns := leading_parser
Term.matchAltsWhereDecls
def whereStructField := leading_parser
Term.letDecl
def whereStructInst := leading_parser
ppIndent ppSpace >> "where" >> Term.structInstFields (sepByIndent (ppGroup whereStructField) "; " (allowTrailingSep := true)) >>
ppIndent ppSpace >> "where" >> Term.structInstFields (sepByIndent Term.structInstField "; " (allowTrailingSep := true)) >>
optional Term.whereDecls
/-- `declVal` matches the right-hand side of a declaration, one of:
* `:= expr` (a "simple declaration")
Expand Down
Loading

0 comments on commit c398f63

Please sign in to comment.