From 4430330c6f95106b68ca157f438c72202d1ca54c Mon Sep 17 00:00:00 2001 From: Patrick Ferris Date: Sat, 14 Sep 2024 10:27:31 +0100 Subject: [PATCH] Make pexp_fun collect its arguments to form max arity functions Signed-off-by: Patrick Ferris --- src/ast_builder.ml | 29 ++++++++++++++++++----------- src/ast_builder.mli | 19 ++++++------------- traverse/ppxlib_traverse.ml | 25 +++++-------------------- 3 files changed, 29 insertions(+), 44 deletions(-) diff --git a/src/ast_builder.ml b/src/ast_builder.ml index b1642a641..2ea853753 100644 --- a/src/ast_builder.ml +++ b/src/ast_builder.ml @@ -33,16 +33,6 @@ module Default = struct ppat_desc = Ppat_construct (lid, Option.map p ~f:(fun p -> ([], p))); } - let pexp_fun ~loc (label : arg_label) expr p e = - let pparam_desc = Pparam_val (label, expr, p) in - let case = { pparam_desc; pparam_loc = loc } in - { - pexp_loc_stack = []; - pexp_attributes = []; - pexp_loc = loc; - pexp_desc = Pexp_function ([ case ], None, Pfunction_body e); - } - let pexp_function_cases ~loc cases = { pexp_loc_stack = []; @@ -51,7 +41,24 @@ module Default = struct pexp_desc = Pexp_function ([], None, Pfunction_cases (cases, loc, [])); } - let pexp_function ~loc cases = pexp_function_cases ~loc cases + (* let pexp_function ~loc cases = pexp_function_cases ~loc cases *) + + let add_fun_params return_constraint ~loc params body = + match params with + | [] -> body + | _ -> ( + match body.pexp_desc with + | Pexp_function (more_params, constraint_, func_body) -> + pexp_function ~loc (params @ more_params) constraint_ func_body + | _ -> + assert (match params with [] -> false | _ -> true); + pexp_function ~loc params return_constraint (Pfunction_body body)) + + let pexp_fun ~loc (label : arg_label) expr p e = + let param : function_param = + { pparam_desc = Pparam_val (label, expr, p); pparam_loc = loc } + in + add_fun_params ~loc None [ param ] e let value_binding ~loc ~pat ~expr = value_binding ~loc ~pat ~expr ~constraint_:None diff --git a/src/ast_builder.mli b/src/ast_builder.mli index 42749f804..ea33f819d 100644 --- a/src/ast_builder.mli +++ b/src/ast_builder.mli @@ -35,13 +35,6 @@ module Default : sig (label loc list * pattern) option -> pattern - val pexp_function : - loc:location -> - function_param list -> - type_constraint option -> - function_body -> - expression - val value_binding : ?constraint_:value_constraint -> loc:location -> @@ -63,12 +56,12 @@ module Default : sig val ppat_construct : loc:location -> longident loc -> pattern option -> pattern - val pexp_function : loc:location -> Import.cases -> expression - [@@ocaml.deprecated "use pexp_function_cases instead."] - (** @deprecated - This function will be used to construct a {! Parsetree.Pexp_function } - in the next release, to retain its current functionality migrate to - {! pexp_function_cases}. *) + val pexp_function : + loc:location -> + function_param list -> + type_constraint option -> + function_body -> + expression val pexp_function_cases : loc:location -> Import.cases -> expression (** [pexp_function_cases] builds an expression in the shape diff --git a/traverse/ppxlib_traverse.ml b/traverse/ppxlib_traverse.ml index eaba6c1dc..48e21caf7 100644 --- a/traverse/ppxlib_traverse.ml +++ b/traverse/ppxlib_traverse.ml @@ -520,26 +520,11 @@ let gen_mapper ~(what : what) td = | None -> what#any ~loc | Some te -> type_expr_mapper ~what te) in - let params = - List.map - ~f:(fun (ty, _) -> - let loc = ty.ptyp_loc in - let desc = - match ty.ptyp_desc with - | Ptyp_var s -> pvar ~loc ("_" ^ s) - | _ -> ppat_any ~loc - in - let pparam_desc = Pparam_val (Nolabel, None, desc) in - { pparam_loc = loc; pparam_desc }) - td.ptype_params - in - let pexp_desc = Pexp_function (params, None, Pfunction_body body) in - { - pexp_desc; - pexp_loc = td.ptype_loc; - pexp_loc_stack = []; - pexp_attributes = []; - } + List.fold_right td.ptype_params ~init:body ~f:(fun (ty, _) acc -> + let loc = ty.ptyp_loc in + match ty.ptyp_desc with + | Ptyp_var s -> pexp_fun ~loc Nolabel None (pvar ~loc ("_" ^ s)) acc + | _ -> pexp_fun ~loc Nolabel None (ppat_any ~loc) acc) let type_deps = let collect =