diff --git a/metaquot/ppxlib_metaquot.ml b/metaquot/ppxlib_metaquot.ml index 1de3eeff4..d4d3ae9e5 100644 --- a/metaquot/ppxlib_metaquot.ml +++ b/metaquot/ppxlib_metaquot.ml @@ -11,6 +11,19 @@ type quoted_attributes = { attributes are placed on, e.g. pexp_attributes. *) } +let coalesce_arity (input : expression) result = + match input with + | { pexp_desc = Pexp_function _; pexp_loc = loc; _ } -> + let ppxlib_coalesce_arity = + Ldot + ( Ldot (Ldot (Lident "Ppxlib", "Ast_builder"), "Default"), + "coalesce_arity" ) + in + pexp_apply ~loc + (pexp_ident ~loc { txt = ppxlib_coalesce_arity; loc }) + [ (Nolabel, result) ] + | _ -> result + module Make (M : sig type result @@ -33,6 +46,7 @@ module Make (M : sig val location : location -> result val location_stack : (location -> result) option val attributes : (location -> result) option + val coalesce : (expression -> result -> result) option class std_lifters : location -> [result] Ppxlib_traverse_builtins.std_lifters end) = @@ -68,8 +82,8 @@ struct | Some f -> f loc method! expression e = - match e.pexp_desc with - | Pexp_extension (({ txt = "e"; _ }, _) as ext) -> + match (M.coalesce, e.pexp_desc) with + | _, Pexp_extension (({ txt = "e"; _ }, _) as ext) -> let attributes = { quoted_attributes = e.pexp_attributes; @@ -77,7 +91,8 @@ struct } in M.cast self ext (Some attributes) ~type_name:"expression" - | _ -> super#expression e + | Some f, _ -> f e (super#expression e) + | None, _ -> super#expression e method! pattern p = match p.ppat_desc with @@ -149,6 +164,7 @@ module Expr = Make (struct let location loc = evar ~loc:{ loc with loc_ghost = true } "loc" let location_stack = None let attributes = None + let coalesce = Some coalesce_arity class std_lifters = Ppxlib_metaquot_lifters.expression_lifters @@ -228,6 +244,7 @@ module Patt = Make (struct Some (fun loc -> ppat_any ~loc:{ loc with loc_ghost = true }) let attributes = Some (fun loc -> ppat_any ~loc:{ loc with loc_ghost = true }) + let coalesce = None class std_lifters = Ppxlib_metaquot_lifters.pattern_lifters diff --git a/src/ast_builder.ml b/src/ast_builder.ml index 459270e8b..b1642a641 100644 --- a/src/ast_builder.ml +++ b/src/ast_builder.ml @@ -68,6 +68,22 @@ module Default = struct (*-------------------------------------------------------*) + let coalesce_arity e = + match e.pexp_desc with + (* We stop coalescing parameters if there is a constraint on the result of a function + (i.e [fun x y : T -> ...] or the body is a function_case. *) + | Pexp_function (_, Some _, _) | Pexp_function (_, _, Pfunction_cases _) -> + e + | Pexp_function + (params1, None, Pfunction_body ({ pexp_attributes = []; _ } as body1)) + -> ( + match body1.pexp_desc with + | Pexp_function (params2, constraint_, body2) -> + Latest.pexp_function ~loc:e.pexp_loc (params1 @ params2) constraint_ + body2 + | _ -> e) + | _ -> e + let pstr_value_list ~loc rec_flag = function | [] -> [] | vbs -> [ pstr_value ~loc rec_flag vbs ] diff --git a/src/ast_builder.mli b/src/ast_builder.mli index e2190eaca..42749f804 100644 --- a/src/ast_builder.mli +++ b/src/ast_builder.mli @@ -74,6 +74,14 @@ module Default : sig (** [pexp_function_cases] builds an expression in the shape [function C1 -> E1 | ...]. *) + val coalesce_arity : expression -> expression + (** [coalesce_arity e] will produce a maximum arity function from an + expression. + + For example, [fun x -> fun y -> x + y] becomes [fun x y -> x + y]. Since + OCaml 5.2, these two functions have a different {! Parsetree} + representation. *) + val constructor_declaration : loc:location -> name:label loc -> diff --git a/test/metaquot/test.ml b/test/metaquot/test.ml index 6e4ac5c02..b33049e81 100644 --- a/test/metaquot/test.ml +++ b/test/metaquot/test.ml @@ -602,3 +602,16 @@ Line _, characters 36-38: Error: This expression should not be a unit literal, the expected type is Ppxlib_ast.Ast.module_type |}] + +(* Coalescing arguments from [fun x -> fun y -> fun z -> ...] to + [fun x y z -> ...] *) +let _ = + let e = [%expr fun z -> x + y + z] in + let f = [%expr fun y -> [%e e]] in + let func = [%expr fun x -> [%e f]] in + Format.asprintf "%a" Pprintast.expression func + + +[%%expect{| +- : string = "fun x y z -> (x + y) + z" +|}] diff --git a/traverse/ppxlib_traverse.ml b/traverse/ppxlib_traverse.ml index 48e21caf7..eaba6c1dc 100644 --- a/traverse/ppxlib_traverse.ml +++ b/traverse/ppxlib_traverse.ml @@ -520,11 +520,26 @@ let gen_mapper ~(what : what) td = | None -> what#any ~loc | Some te -> type_expr_mapper ~what te) in - List.fold_right td.ptype_params ~init:body ~f:(fun (ty, _) acc -> - let loc = ty.ptyp_loc in - match ty.ptyp_desc with - | Ptyp_var s -> pexp_fun ~loc Nolabel None (pvar ~loc ("_" ^ s)) acc - | _ -> pexp_fun ~loc Nolabel None (ppat_any ~loc) acc) + let params = + List.map + ~f:(fun (ty, _) -> + let loc = ty.ptyp_loc in + let desc = + match ty.ptyp_desc with + | Ptyp_var s -> pvar ~loc ("_" ^ s) + | _ -> ppat_any ~loc + in + let pparam_desc = Pparam_val (Nolabel, None, desc) in + { pparam_loc = loc; pparam_desc }) + td.ptype_params + in + let pexp_desc = Pexp_function (params, None, Pfunction_body body) in + { + pexp_desc; + pexp_loc = td.ptype_loc; + pexp_loc_stack = []; + pexp_attributes = []; + } let type_deps = let collect =