diff --git a/src/Lean/Elab/App.lean b/src/Lean/Elab/App.lean index 3ecd61a814bf..3d7636f233a7 100644 --- a/src/Lean/Elab/App.lean +++ b/src/Lean/Elab/App.lean @@ -1414,31 +1414,19 @@ private def resolveLValAux (e : Expr) (eType : Expr) (lval : LVal) : TermElabM L if isStructure env structName then if let some baseStructName := findField? env structName (Name.mkSimple fieldName) then return LValResolution.projFn baseStructName structName (Name.mkSimple fieldName) - -- Search the local context first - let fullName := Name.mkStr (privateToUserName structName) fieldName - for localDecl in (← getLCtx) do - if localDecl.isAuxDecl then - if let some localDeclFullName := (← getLCtx).auxDeclToFullName.get? localDecl.fvarId then - if fullName == privateToUserName localDeclFullName then - /- LVal notation is being used to make a "local" recursive call. -/ - return LValResolution.localRec structName localDecl.toExpr - -- Then search the environment - if let some (baseStructName, fullName) ← findMethod? structName (.mkSimple fieldName) then - return LValResolution.const baseStructName structName fullName - throwInvalidFieldAt ref fieldName fullName - -- Suggest a potential unreachable private name as hint. This does not cover structure - -- inheritance, nor `import all`. - (declHint := (mkPrivateName env structName).mkStr fieldName) + resolveFieldName structName fieldName fun fullName => do + throwInvalidFieldAt ref fieldName fullName + -- Suggest a potential unreachable private name as hint. This does not cover structure + -- inheritance, nor `import all`. + (declHint := (mkPrivateName env structName).mkStr fieldName) | .forallE .., LVal.fieldName ref fieldName suffix? fullRef => - let fullName := Name.str `Function fieldName - if (← getEnv).contains fullName then - return LValResolution.const `Function `Function fullName - match e.getAppFn, suffix? with - | Expr.const c _, some suffix => - throwUnknownNameWithSuggestions (idOrConst := "constant") (ref? := fullRef) (c ++ suffix) - | _, _ => - throwInvalidFieldAt ref fieldName fullName + resolveFieldName `Function fieldName fun fullName => do + match e.getAppFn, suffix? with + | Expr.const c _, some suffix => + throwUnknownNameWithSuggestions (idOrConst := "constant") (ref? := fullRef) (c ++ suffix) + | _, _ => + throwInvalidFieldAt ref fieldName fullName | .forallE .., .fieldIdx .. => throwError "Invalid projection: Projections cannot be used on functions, and{indentExpr e}\n\ has function type{inlineExprTrailing eType}" @@ -1469,6 +1457,21 @@ private def resolveLValAux (e : Expr) (eType : Expr) (lval : LVal) : TermElabM L the necessary form." where + resolveFieldName (structName : Name) (fieldName : String) (errk : Name → TermElabM LValResolution) : TermElabM LValResolution := do + let fullName := Name.mkStr (privateToUserName structName) fieldName + -- Search the local context first + for localDecl in (← getLCtx) do + if localDecl.isAuxDecl then + if let some localDeclFullName := (← getLCtx).auxDeclToFullName.get? localDecl.fvarId then + if fullName == privateToUserName localDeclFullName then + /- LVal notation is being used to make a "local" recursive call. -/ + return LValResolution.localRec structName localDecl.toExpr + -- Then search the environment + if let some (baseStructName, fullName) ← findMethod? structName (.mkSimple fieldName) then + return LValResolution.const baseStructName structName fullName + else + errk fullName + throwInvalidFieldAt {α : Type} (ref : Syntax) (fieldName : String) (fullName : Name) (declHint := Name.anonymous) : TermElabM α := do let msg ← @@ -1561,10 +1564,10 @@ private partial def mkBaseProjections (baseStructName : Name) (structName : Name e ← elabAppArgs projFn #[{ name := `self, val := Arg.expr e, suppressDeps := true }] (args := #[]) (expectedType? := none) (explicit := false) (ellipsis := false) return e -private partial def typeMatchesBaseName (type : Expr) (baseName : Name) : MetaM Bool := +private partial def typeMatchesBaseName (bi : BinderInfo) (type : Expr) (baseName : Name) : MetaM Bool := withReducibleAndInstances do if baseName == `Function then - return (← whnf type).isForall + return bi.isExplicit && (← whnf type).isForall else if type.cleanupAnnotations.isAppOf baseName then return true else @@ -1573,7 +1576,7 @@ private partial def typeMatchesBaseName (type : Expr) (baseName : Name) : MetaM return true else match ← unfoldDefinition? type with - | some type' => typeMatchesBaseName type' baseName + | some type' => typeMatchesBaseName bi type' baseName | none => return false /-- @@ -1610,7 +1613,7 @@ where /- If there is named argument with name `xDecl.userName`, then it is accounted for and we can't make use of it. -/ remainingNamedArgs := remainingNamedArgs.eraseIdx idx else - if ← typeMatchesBaseName xDecl.type baseName then + if ← typeMatchesBaseName bInfo xDecl.type baseName then /- We found a type of the form (baseName ...), or we found the first explicit argument in useFirstExplicit mode. First, we check if the current argument is one that can be used positionally, and if the current explicit position "fits" at `args` (i.e., it must be ≤ arg.size) -/ diff --git a/tests/elab/dotNotationFunction.lean b/tests/elab/dotNotationFunction.lean new file mode 100644 index 000000000000..1fb90a7c3f48 --- /dev/null +++ b/tests/elab/dotNotationFunction.lean @@ -0,0 +1,54 @@ +/-! +# Generalized field notation for functions + +Functions use the `Function` namespace, and they always take the first explicit argument that's a function. +-/ + +/-! +Motivating example for why it should only match explicit arguments. This `swap` function +is unusable using field notation in the intended way if it matched implicit arguments too. +https://github.com/leanprover/lean4/issues/1629 +-/ +def Function.swap {α β} {γ : α → β → Sort _} (f : (a : α) → (b : β) → γ a b) + (b : β) (a : α) : γ a b := f a b + +def mul : Nat → Nat → Nat := (· * ·) +/-- info: Function.swap mul : Nat → Nat → Nat -/ +#guard_msgs in #check Function.swap mul -- works +/-- info: Function.swap mul : Nat → Nat → Nat -/ +#guard_msgs in #check mul.swap + +example : mul.swap = Function.swap mul := rfl + +/-! +Function field notation can be `open`ed into other namespaces. +https://leanprover.zulipchat.com/#narrow/channel/113489-new-members/topic/generalized.20field.20notation.20vs.20namespaces/near/582689850 +-/ +def MyNS.Function.apply {α} (a : α) {β : α → Sort _} (f : (x : α) → β x) : β a := f a + +/-- error: Unknown constant `mul.apply` -/ +#guard_msgs in #check mul.apply 2 3 + +/-- info: Function.apply 2 mul 3 : Nat -/ +#guard_msgs in open MyNS in #check mul.apply 2 3 + +/-! +Function field notation can be used in recursive definitions. +-/ +def Function.iterate {α} (f : α → α) (n : Nat) (x : α) : α := + match n with + | 0 => x + | n+1 => f.iterate n (f x) + +/-- info: 1024 -/ +#guard_msgs in #eval (·*2).iterate 10 1 + +/-! +Another example of a definition that is reasonable to use with field notation. This is from Mathlib. +-/ +def Function.update {α : Sort u} {β : α → Sort v} [DecidableEq α] + (f : ∀ a, β a) (a' : α) (v : β a') (a : α) : β a := + if h : a = a' then Eq.ndrec v h.symm else f a + +/-- info: 108 -/ +#guard_msgs in #eval (mul.update 2 (· + 100)) 2 8