Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: parity between structure instance notation and where notation #6165

Merged
merged 3 commits into from
Nov 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions src/Init/Notation.lean
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,7 @@ For example, `let x ← e` is a `doElem`, and a `do` block consists of a list of
def doElem : Category := {}

/-- `structInstFieldDecl` is the syntax category for value declarations for fields in structure instance notation.
For example, the `:= 1` and `where a := 3` in `{ x := 1, y where a := 3 }` are in the `structInstFieldDecl` class.
This category is necessary because structure instance notation is recursive due to the `x where ...` field notation. -/
For example, the `:= 1` and `| 0 => 0 | n + 1 => n` in `{ x := 1, f | 0 => 0 | n + 1 => n }` are in the `structInstFieldDecl` class. -/
def structInstFieldDecl : Category := {}

/-- `level` is a builtin syntax category for universe levels.
Expand Down
77 changes: 31 additions & 46 deletions src/Lean/Elab/MutualDef.lean
Original file line number Diff line number Diff line change
Expand Up @@ -282,52 +282,37 @@ private partial def withFunLocalDecls {α} (headers : Array DefViewElabHeader) (
k fvars
loop 0 #[]

private def expandWhereStructInst : Macro
| whereStx@`(Parser.Command.whereStructInst|where%$whereTk $[$decls:letDecl];* $[$whereDecls?:whereDecls]?) => do
let letIdDecls ← decls.mapM fun stx => match stx with
| `(letDecl|$_decl:letPatDecl) => Macro.throwErrorAt stx "patterns are not allowed here"
| `(letDecl|$decl:letEqnsDecl) => expandLetEqnsDecl decl (useExplicit := false)
| `(letDecl|$decl:letIdDecl) => pure decl
| _ => Macro.throwUnsupported
let structInstFields ← letIdDecls.mapM fun
| stx@`(letIdDecl|$id:ident $binders* $[: $ty?]? := $val) => withRef stx do
let mut val := val
if let some ty := ty? then
val ← `(($val : $ty))
-- HACK: this produces invalid syntax, but the fun elaborator supports letIdBinders as well
have : Coe (TSyntax ``letIdBinder) (TSyntax ``funBinder) := ⟨(⟨·⟩)⟩
val ← if binders.size > 0 then `(fun $binders* => $val) else pure val
`(structInstField|$id:ident := $val)
| stx@`(letIdDecl|_ $_* $[: $_]? := $_) => Macro.throwErrorAt stx "'_' is not allowed here"
| _ => Macro.throwUnsupported

let startOfStructureTkInfo : SourceInfo :=
match whereTk.getPos? with
| some pos => .synthetic pos ⟨pos.byteIdx + 1⟩ true
| none => .none
-- Position the closing `}` at the end of the trailing whitespace of `where $[$_:letDecl];*`.
-- We need an accurate range of the generated structure instance in the generated `TermInfo`
-- so that we can determine the expected type in structure field completion.
let structureStxTailInfo :=
whereStx[1].getTailInfo?
<|> whereStx[0].getTailInfo?
let endOfStructureTkInfo : SourceInfo :=
match structureStxTailInfo with
| some (SourceInfo.original _ _ trailing _) =>
let tokenPos := trailing.str.prev trailing.stopPos
let tokenEndPos := trailing.stopPos
.synthetic tokenPos tokenEndPos true
| _ => .none

let body ← `(structInst| { $structInstFields,* })
let body := body.raw.setInfo <|
match startOfStructureTkInfo.getPos?, endOfStructureTkInfo.getTailPos? with
| some startPos, some endPos => .synthetic startPos endPos true
| _, _ => .none
match whereDecls? with
| some whereDecls => expandWhereDecls whereDecls body
| none => return body
| _ => Macro.throwUnsupported
private def expandWhereStructInst : Macro := fun whereStx => do
let whereTk := whereStx[0]
let structInstFields : TSyntaxArray ``Parser.Term.structInstField := .mk whereStx[1][0].getSepArgs
let whereDecls? := whereStx[2].getOptional?

let startOfStructureTkInfo : SourceInfo :=
match whereTk.getPos? with
| some pos => .synthetic pos ⟨pos.byteIdx + 1⟩ true
| none => .none
-- Position the closing `}` at the end of the trailing whitespace of `where $[$_:letDecl];*`.
-- We need an accurate range of the generated structure instance in the generated `TermInfo`
-- so that we can determine the expected type in structure field completion.
let structureStxTailInfo :=
whereStx[1].getTailInfo?
<|> whereStx[0].getTailInfo?
let endOfStructureTkInfo : SourceInfo :=
match structureStxTailInfo with
| some (SourceInfo.original _ _ trailing _) =>
let tokenPos := trailing.str.prev trailing.stopPos
let tokenEndPos := trailing.stopPos
.synthetic tokenPos tokenEndPos true
| _ => .none

let body ← `(structInst| { $structInstFields,* })
let body := body.raw.setInfo <|
match startOfStructureTkInfo.getPos?, endOfStructureTkInfo.getTailPos? with
| some startPos, some endPos => .synthetic startPos endPos true
| _, _ => .none
match whereDecls? with
| some whereDecls => expandWhereDecls whereDecls body
| none => return body

/-
Recall that
Expand Down
15 changes: 10 additions & 5 deletions src/Lean/Elab/PatternVar.lean
Original file line number Diff line number Diff line change
Expand Up @@ -261,11 +261,16 @@ partial def collect (stx : Syntax) : M Syntax := withRef stx <| withFreshMacroSc
| `({ $[$srcs?,* with]? $fields,* $[..%$ell?]? $[: $ty?]? }) =>
if let some srcs := srcs? then
throwErrorAt (mkNullNode srcs) "invalid struct instance pattern, 'with' is not allowed in patterns"
let fields ← fields.getElems.mapM fun
| `(Parser.Term.structInstField| $lval:structInstLVal := $val) => do
let newVal ← collect val
`(Parser.Term.structInstField| $lval:structInstLVal := $newVal)
| _ => throwInvalidPattern -- `structInstFieldAbbrev` should be expanded at this point
-- TODO(kmill) restore this
-- let fields ← fields.getElems.mapM fun
-- | `(Parser.Term.structInstField| $lval:structInstLVal := $val) => do
-- let newVal ← collect val
-- `(Parser.Term.structInstField| $lval:structInstLVal := $newVal)
-- | _ => throwInvalidPattern -- `structInstFieldAbbrev` should be expanded at this point
let fields ← fields.getElems.mapM fun field => do
let field := field.raw
let val ← collect field[1][2][1]
pure <| field.setArg 1 <| field[1].setArg 2 <| field[1][2].setArg 1 val
`({ $[$srcs?,* with]? $fields,* $[..%$ell?]? $[: $ty?]? })
| _ => throwInvalidPattern

Expand Down
135 changes: 106 additions & 29 deletions src/Lean/Elab/StructInst.lean
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,32 @@ open Meta
open TSyntax.Compat

/-!
Recall that structure instances are of the form:
```
"{" >> optional (atomic (sepBy1 termParser ", " >> " with "))
>> manyIndent (group ((structInstFieldAbbrev <|> structInstField) >> optional ", "))
Recall that structure instances are (after removing parsing and pretty printing hints):

```lean
def structInst := leading_parser
"{ " >> optional (sepBy1 termParser ", " >> " with ")
>> structInstFields (sepByIndent structInstField ", " (allowTrailingSep := true))
>> optEllipsis
>> optional (" : " >> termParser)
>> " }"
>> optional (" : " >> termParser) >> " }"

def structInstField := leading_parser
structInstLVal >> optional (many structInstFieldBinder >> optType >> structInstFieldDecl)

@[builtin_structInstFieldDecl_parser]
def structInstFieldDef := leading_parser
" := " >> termParser

@[builtin_structInstFieldDecl_parser]
def structInstFieldEqns := leading_parser
matchAlts

def structInstWhereBody := leading_parser
structInstFields (sepByIndent structInstField "; " (allowTrailingSep := true))

@[builtin_structInstFieldDecl_parser]
def structInstFieldWhere := leading_parser
"where" >> structInstWhereBody
```
-/

Expand All @@ -54,21 +73,74 @@ Structure instance notation makes use of the expected type.
let stxNew := stx.setArg 4 mkNullNode
`(($stxNew : $expected))

def mkStructInstField (lval : TSyntax ``Parser.Term.structInstLVal) (binders : TSyntaxArray ``Parser.Term.structInstFieldBinder)
(type? : Option Term) (val : Term) : MacroM Term := do
let mut val := val
if let some type := type? then
val ← `(($val : $type))
if !binders.isEmpty then
-- HACK: this produces invalid syntax, but the fun elaborator supports structInstFieldBinder as well
val ← `(fun $binders* => $val)
-- `(Parser.Term.structInstField| $lval := $val)
return mkNode ``Parser.Term.structInstField
#[lval, mkNullNode #[mkNullNode, mkNullNode, mkNode ``Parser.Term.structInstFieldDef #[mkAtom " := ", val]]]

/--
Expands field abbreviation notation.
Example: `{ x, y := 0 }` expands to `{ x := x, y := 0 }`.
Takes an arbitrary `structInstField` and expands it to be a `structInstFieldDef` without any binders or type ascription.
-/
@[builtin_macro Lean.Parser.Term.structInst] def expandStructInstFieldAbbrev : Macro
| `({ $[$srcs,* with]? $fields,* $[..%$ell]? $[: $ty]? }) =>
if fields.getElems.raw.any (·.getKind == ``Lean.Parser.Term.structInstFieldAbbrev) then do
let fieldsNew ← fields.getElems.mapM fun
| `(Parser.Term.structInstFieldAbbrev| $id:ident) =>
`(Parser.Term.structInstField| $id:ident := $id:ident)
| field => return field
`({ $[$srcs,* with]? $fieldsNew,* $[..%$ell]? $[: $ty]? })
private def expandStructInstField (stx : Syntax) : MacroM (Option Syntax) := withRef stx do
if stx.isOfKind `Lean.Parser.Term.structInstField && stx.getNumArgs == 3 then
-- old syntax
let lval : TSyntax ``Parser.Term.structInstLVal := stx[0]
let val : Term := stx[2]
mkStructInstField lval #[] none val
else if stx.isOfKind `Lean.Parser.Term.structInstFieldAbbrev then
-- old syntax
let id : Ident := stx[0]
let lval ← `(Parser.Term.structInstLVal| $id:ident)
mkStructInstField lval #[] none id
else if stx.isOfKind ``Parser.Term.structInstField then
let lval := stx[0]
if stx[1].getNumArgs > 0 then
let binders := stx[1][0].getArgs
let ty? := match stx[1][1] with | `(Parser.Term.optTypeForStructInst| $[: $ty?]?) => ty? | _ => none
let decl := stx[1][2]
match decl with
| `(Parser.Term.structInstFieldDef| := $val) =>
if binders.isEmpty && ty?.isNone then
return none
else
mkStructInstField lval binders ty? val
| `(Parser.Term.structInstFieldEqns| $alts:matchAlts) =>
let val ← expandMatchAltsIntoMatch stx alts (useExplicit := false)
mkStructInstField lval binders ty? val
| _ => Macro.throwUnsupported
else
Macro.throwUnsupported
| _ => Macro.throwUnsupported
-- Abbreviation
match lval with
| `(Parser.Term.structInstLVal| $id:ident) =>
mkStructInstField lval #[] none id
| _ =>
Macro.throwErrorAt lval "unsupported structure instance field abbreviation, expecting identifier"
else
Macro.throwUnsupported

/--
Expands fields.
* Abbrevations. Example: `{ x }` expands to `{ x := x }`.
* Equations. Example: `{ f | 0 => 0 | n + 1 => n }` expands to `{ f := fun x => match x with | 0 => 0 | n + 1 => n }`.
* `where`. Example: `{ s where x := 1 }` expands to `{ s := { x := 1 }}`.
* Binders and types. Example: `{ f n : Nat := n + 1 }` expands to `{ f := fun n => (n + 1 : Nat) }`.
-/
@[builtin_macro Lean.Parser.Term.structInst] def expandStructInstFields : Macro | stx => do
let structInstFields := stx[2]
let fields := structInstFields[0].getSepArgs
let fields? ← fields.mapM expandStructInstField
if fields?.all (·.isNone) then
Macro.throwUnsupported
let fields := fields?.zipWith fields Option.getD
let structInstFields := structInstFields.setArg 0 <| Syntax.mkSep fields (mkAtomFrom stx ", ")
return stx.setArg 2 structInstFields

/--
If `stx` is of the form `{ s₁, ..., sₙ with ... }` and `sᵢ` is not a local variable,
Expand Down Expand Up @@ -187,12 +259,13 @@ def structInstArrayRef := leading_parser "[" >> termParser >>"]"
-/
private def isModifyOp? (stx : Syntax) : TermElabM (Option Syntax) := do
let s? ← stx[2][0].getSepArgs.foldlM (init := none) fun s? arg => do
/- arg is of the form `structInstFieldAbbrev <|> structInstField` -/
if arg.getKind == ``Lean.Parser.Term.structInstField then
/- Remark: the syntax for `structInstField` is
/- arg is of the form `structInstField`. It should be macro expanded at this point, but we make sure it's the case. -/
if arg[1][2].getKind == ``Lean.Parser.Term.structInstFieldDef then
/- Remark: the syntax for `structInstField` after macro expansion is
```
def structInstLVal := leading_parser (ident <|> numLit <|> structInstArrayRef) >> many (group ("." >> (ident <|> numLit)) <|> structInstArrayRef)
def structInstField := leading_parser structInstLVal >> " := " >> termParser
def structInstFieldDef := leading_parser
structInstLVal >> group (null >> null >> group (" := " >> termParser))
```
-/
let lval := arg[0]
Expand Down Expand Up @@ -235,7 +308,7 @@ private def elabModifyOp (stx modifyOp : Syntax) (sources : Array ExplicitSource
withMacroExpansion stx stxNew <| elabTerm stxNew expectedType?
let rest := modifyOp[0][1]
if rest.isNone then
cont modifyOp[2]
cont modifyOp[1][2][1]
else
let s ← `(s)
let valFirst := rest[0]
Expand Down Expand Up @@ -412,7 +485,7 @@ Converts a `Field StructInstView` back into syntax. Used to construct synthetic
private def Field.toSyntax : Field → Syntax
| field =>
let stx := field.ref
let stx := stx.setArg 2 field.val.toSyntax
let stx := stx.setArg 1 <| stx[1].setArg 2 <| stx[1][2].setArg 1 field.val.toSyntax
match field.lhs with
| first::rest => stx.setArg 0 <| mkNullNode #[first.toSyntax true, mkNullNode <| rest.toArray.map (FieldLHS.toSyntax false) ]
| _ => unreachable!
Expand All @@ -428,7 +501,7 @@ private def toFieldLHS (stx : Syntax) : MacroM FieldLHS :=
return FieldLHS.fieldName stx stx.getId.eraseMacroScopes
else match stx.isFieldIdx? with
| some idx => return FieldLHS.fieldIndex stx idx
| none => Macro.throwError "unexpected structure syntax"
| none => Macro.throwErrorAt stx "unexpected structure syntax"

/--
Creates a structure instance view from structure instance notation
Expand All @@ -439,16 +512,20 @@ private def mkStructView (stx : Syntax) (structName : Name) (sources : SourcesVi
/- Recall that `stx` is of the form
```
leading_parser "{" >> optional (atomic (sepBy1 termParser ", " >> " with "))
>> structInstFields (sepByIndent (structInstFieldAbbrev <|> structInstField) ...)
>> structInstFields (sepByIndent structInstField ...)
>> optional ".."
>> optional (" : " >> termParser)
>> " }"
```

This method assumes that `structInstFieldAbbrev` had already been expanded.
This method assumes that `structInstField` had already been expanded by the macro `expandStructInstFields`
and is of the form
```
def structInstFieldDef := leading_parser
structInstLVal >> group (null >> null >> group (" := " >> termParser))
```
-/
let fields ← stx[2][0].getSepArgs.toList.mapM fun fieldStx => do
let val := fieldStx[2]
let val := fieldStx[1][2][1]
let first ← toFieldLHS fieldStx[0][0]
let rest ← fieldStx[0][1].getArgs.toList.mapM toFieldLHS
return { ref := fieldStx, lhs := first :: rest, val := FieldVal.term val : Field }
Expand Down
2 changes: 1 addition & 1 deletion src/Lean/Linter/UnusedVariables.lean
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ builtin_initialize addBuiltinUnusedVariablesIgnoreFn (fun _ stack opts =>
!getLinterUnusedVariablesFunArgs opts &&
stack.matches [`null, none, `null, ``Lean.Parser.Term.letIdDecl, none] &&
(stack.get? 3 |>.any fun (_, pos) => pos == 1) &&
(stack.get? 5 |>.any fun (stx, _) => !stx.isOfKind ``Lean.Parser.Command.whereStructField))
(stack.get? 5 |>.any fun (stx, _) => !stx.isOfKind ``Lean.Parser.Term.structInstField))

/--
Function argument in declaration signature (when `linter.unusedVariables.funArgs` is false)
Expand Down
4 changes: 1 addition & 3 deletions src/Lean/Parser/Command.lean
Original file line number Diff line number Diff line change
Expand Up @@ -134,10 +134,8 @@ def declValSimple := leading_parser
" :=" >> ppHardLineUnlessUngrouped >> declBody >> Termination.suffix >> optional Term.whereDecls
def declValEqns := leading_parser
Term.matchAltsWhereDecls
def whereStructField := leading_parser
Term.letDecl
def whereStructInst := leading_parser
ppIndent ppSpace >> "where" >> Term.structInstFields (sepByIndent (ppGroup whereStructField) "; " (allowTrailingSep := true)) >>
ppIndent ppSpace >> "where" >> Term.structInstFields (sepByIndent Term.structInstField "; " (allowTrailingSep := true)) >>
optional Term.whereDecls
/-- `declVal` matches the right-hand side of a declaration, one of:
* `:= expr` (a "simple declaration")
Expand Down
Loading
Loading