From 220fdb2c9c68ba93faefe724c1442b8787cb8b1c Mon Sep 17 00:00:00 2001 From: Patrick Ferris Date: Sun, 9 Jun 2024 11:18:31 +0100 Subject: [PATCH] Refactor driver's transformations To make the driver's handling of transformations clearer, we convert them into specific kinds of passes on the AST (e.g. distinguishing between linters and preprocessors). Signed-off-by: Patrick Ferris --- src/driver.ml | 355 +++++++++++++++++++++++++++++--------------------- 1 file changed, 207 insertions(+), 148 deletions(-) diff --git a/src/driver.ml b/src/driver.ml index 659df1d3..443830d5 100644 --- a/src/driver.ml +++ b/src/driver.ml @@ -114,52 +114,47 @@ module Instrument = struct end module Transform = struct - type t = { + (* A full rewrite of implementation OCaml code i.e. .ml files *) + type 'result full_impl_ast_rewrite_pass = + Expansion_context.Base.t -> Parsetree.structure -> 'result + + (* A full rewrite of interface OCaml code i.e. .mli files *) + type 'result full_intf_ast_rewrite_pass = + Expansion_context.Base.t -> Parsetree.signature -> 'result + + type 'result enclosing_header_footer = + Expansion_context.Base.t -> Location.t option -> 'result * 'result + + (* Meta data about the transformation like its name. *) + type meta = { name : string; aliases : string list; - impl : - (Expansion_context.Base.t -> - Parsetree.structure -> - Parsetree.structure With_errors.t) - option; - intf : - (Expansion_context.Base.t -> - Parsetree.signature -> - Parsetree.signature With_errors.t) - option; - lint_impl : - (Expansion_context.Base.t -> Parsetree.structure -> Lint_error.t list) - option; - lint_intf : - (Expansion_context.Base.t -> Parsetree.signature -> Lint_error.t list) - option; + registered_at : Caller_id.t; + } + + (* A Transform.t represents a raw transformation as registered + by a PPX to be applied to the AST. Most likely only a single + optional field will be filled with a corresponding transformation, + the rest will be [None]. *) + type t = { + meta : meta; + impl : Parsetree.structure With_errors.t full_impl_ast_rewrite_pass option; + intf : Parsetree.signature With_errors.t full_intf_ast_rewrite_pass option; + lint_impl : Lint_error.t list full_impl_ast_rewrite_pass option; + lint_intf : Lint_error.t list full_intf_ast_rewrite_pass option; preprocess_impl : - (Expansion_context.Base.t -> - Parsetree.structure -> - Parsetree.structure With_errors.t) - option; + Parsetree.structure With_errors.t full_impl_ast_rewrite_pass option; preprocess_intf : - (Expansion_context.Base.t -> - Parsetree.signature -> - Parsetree.signature With_errors.t) - option; - enclose_impl : - (Expansion_context.Base.t -> - Location.t option -> - Parsetree.structure * Parsetree.structure) - option; - enclose_intf : - (Expansion_context.Base.t -> - Location.t option -> - Parsetree.signature * Parsetree.signature) - option; + Parsetree.signature With_errors.t full_intf_ast_rewrite_pass option; + enclose_impl : Parsetree.structure enclosing_header_footer option; + enclose_intf : Parsetree.signature enclosing_header_footer option; instrument : Instrument.t option; rules : Context_free.Rule.t list; - registered_at : Caller_id.t; } let has_name t name = - String.equal name t.name || List.exists ~f:(String.equal name) t.aliases + String.equal name t.meta.name + || List.exists ~f:(String.equal name) t.meta.aliases let all : t list ref = ref [] @@ -179,7 +174,7 @@ module Transform = struct Printf.eprintf "Warning: code transformation %s registered twice.\n" name; Printf.eprintf " - first time was at %a\n" print_caller_id - ct.registered_at; + ct.meta.registered_at; Printf.eprintf " - second time is at %a\n" print_caller_id caller_id); let impl = Option.map impl ~f:(fun f ctx ast -> return (f ctx ast)) in let intf = Option.map intf ~f:(fun f ctx ast -> return (f ctx ast)) in @@ -189,10 +184,10 @@ module Transform = struct let preprocess_intf = Option.map preprocess_intf ~f:(fun f ctx ast -> return (f ctx ast)) in + let meta = { name; aliases; registered_at = caller_id } in let ct = { - name; - aliases; + meta; rules; enclose_impl; enclose_intf; @@ -203,11 +198,54 @@ module Transform = struct preprocess_intf; lint_intf; instrument; - registered_at = caller_id; } in all := ct :: !all + (* An individual pass is of a particular kind (e.g. lint, preprocess etc.) + and may be for the implementation, the interface or both *) + type ('impl, 'intf) pass = { + meta : meta; + pass : [ `Impl of 'impl | `Intf of 'intf | `Both of 'impl * 'intf ]; + } + + let get_pass_impl = function + | { pass = `Impl i | `Both (i, _); _ } -> Some i + | _ -> None + + let get_pass_intf = function + | { pass = `Intf i | `Both (_, i); _ } -> Some i + | _ -> None + + (* We distinguish between different kinds of transformations: + - Linters: Are transformations applied to unprocessed files. + - Preprocessors: The main preprocessing transformations. + - Generic: Essentially any transformation that is not a linter + or preprocessor is a generic transformation. + - Enclosers: Enclosers wrap implementations and interfaces with + headers and footers. They eventually are mapped into generic + transformations. + - Instrumentation: these are mapped into generic transformations. + *) + type transform = + [ `Generic of + ( Parsetree.structure With_errors.t full_impl_ast_rewrite_pass, + Parsetree.signature With_errors.t full_intf_ast_rewrite_pass ) + pass + | `Linter of + ( Lint_error.t list full_impl_ast_rewrite_pass, + Lint_error.t list full_intf_ast_rewrite_pass ) + pass + | `Preprocessor of + ( Parsetree.structure With_errors.t full_impl_ast_rewrite_pass, + Parsetree.signature With_errors.t full_intf_ast_rewrite_pass ) + pass ] + + let get_transform_meta : transform -> meta = function + | `Generic { meta; _ } -> meta + | `Linter { meta; _ } -> meta + | `Preprocessor { meta; _ } -> meta + let rec last prev l = match l with [] -> prev | x :: l -> last x l let loc_of_list ~get_loc l = @@ -302,14 +340,20 @@ module Transform = struct map#signature base_ctxt (List.concat [ attrs; header; sg; footer ]) >>= fun sg -> match intf with None -> return sg | Some f -> f ctxt sg in - { t with impl = Some map_impl; intf = Some map_intf } + { meta = t.meta; pass = `Both (map_impl, map_intf) } let builtin_of_context_free_rewriters ~hook ~rules ~enclose_impl ~enclose_intf ~input_name = - merge_into_generic_mappers ~hook ~input_name + let meta = { name = ""; aliases = []; + registered_at = Caller_id.get ~skip:[]; + } + in + merge_into_generic_mappers ~hook ~input_name + { + meta; impl = None; intf = None; lint_impl = None; @@ -320,7 +364,6 @@ module Transform = struct enclose_intf; instrument = None; rules; - registered_at = Caller_id.get ~skip:[]; } let partition_transformations ts = @@ -340,58 +383,45 @@ module Transform = struct in match Option.map t.instrument ~f with | Some (Before, transf) -> - ( { reduced_t with impl = Some transf; rules = [] } :: bef_i, - aft_i, - reduced_t :: rest ) + let before = + `Generic { meta = reduced_t.meta; pass = `Impl transf } + in + (before :: bef_i, aft_i, reduced_t :: rest) | Some (After, transf) -> - ( bef_i, - { reduced_t with impl = Some transf; rules = [] } :: aft_i, - reduced_t :: rest ) + let after = + `Generic { meta = reduced_t.meta; pass = `Impl transf } + in + (bef_i, after :: aft_i, reduced_t :: rest) | None -> (bef_i, aft_i, reduced_t :: rest)) in - ( `Linters - (List.filter_map ts ~f:(fun t -> - if Option.is_some t.lint_impl || Option.is_some t.lint_intf then - Some - { - name = Printf.sprintf "" t.name; - aliases = []; - impl = None; - intf = None; - lint_impl = t.lint_impl; - lint_intf = t.lint_intf; - enclose_impl = None; - enclose_intf = None; - preprocess_impl = None; - preprocess_intf = None; - instrument = None; - rules = []; - registered_at = t.registered_at; - } - else None)), - `Preprocess - (List.filter_map ts ~f:(fun t -> - if - Option.is_some t.preprocess_impl - || Option.is_some t.preprocess_intf - then - Some - { - name = Printf.sprintf "" t.name; - aliases = []; - impl = t.preprocess_impl; - intf = t.preprocess_intf; - lint_impl = None; - lint_intf = None; - enclose_impl = None; - enclose_intf = None; - preprocess_impl = None; - preprocess_intf = None; - instrument = None; - rules = []; - registered_at = t.registered_at; - } - else None)), + ( List.filter_map ts ~f:(fun t -> + let meta = + { + t.meta with + name = Printf.sprintf "" t.meta.name; + aliases = []; + } + in + match (t.lint_impl, t.lint_intf) with + | Some impl, Some intf -> + Some (`Linter { meta; pass = `Both (impl, intf) }) + | Some impl, None -> Some (`Linter { meta; pass = `Impl impl }) + | None, Some intf -> Some (`Linter { meta; pass = `Intf intf }) + | _ -> None), + List.filter_map ts ~f:(fun t -> + let meta = + { + t.meta with + name = Printf.sprintf "" t.meta.name; + aliases = []; + } + in + match (t.preprocess_impl, t.preprocess_intf) with + | Some impl, Some intf -> + Some (`Preprocessor { meta; pass = `Both (impl, intf) }) + | Some impl, None -> Some (`Preprocessor { meta; pass = `Impl impl }) + | None, Some intf -> Some (`Preprocessor { meta; pass = `Intf intf }) + | _ -> None), `Before_instrs before_instrs, `After_instrs after_instrs, `Rest rest ) @@ -456,7 +486,7 @@ let debug_dropped_attribute name ~old_dropped ~new_dropped = print_diff "reappeared" old_dropped new_dropped let get_whole_ast_passes ~embed_errors ~hook ~expect_mismatch_handler ~tool_name - ~input_name = + ~input_name : Transform.transform list = let cts = match !apply_list with | None -> List.rev !Transform.all @@ -465,8 +495,8 @@ let get_whole_ast_passes ~embed_errors ~hook ~expect_mismatch_handler ~tool_name List.find !Transform.all ~f:(fun (ct : Transform.t) -> Transform.has_name ct name)) in - let ( `Linters linters, - `Preprocess preprocess, + let ( linters, + preprocess, `Before_instrs before_instrs, `After_instrs after_instrs, `Rest cts ) = @@ -475,7 +505,8 @@ let get_whole_ast_passes ~embed_errors ~hook ~expect_mismatch_handler ~tool_name (* Allow only one preprocessor to assure deterministic order *) (if List.length preprocess > 1 then let pp = - String.concat ~sep:", " (List.map preprocess ~f:(fun t -> t.name)) + String.concat ~sep:", " + (List.map preprocess ~f:(fun (`Preprocessor t) -> t.meta.name)) in let err = Printf.sprintf "At most one preprocessor is allowed, while got: %s" pp @@ -483,49 +514,70 @@ let get_whole_ast_passes ~embed_errors ~hook ~expect_mismatch_handler ~tool_name failwith err); let make_generic transforms = if !no_merge then - List.map transforms - ~f: - (Transform.merge_into_generic_mappers ~embed_errors ~hook ~tool_name - ~expect_mismatch_handler ~input_name) + List.map transforms ~f:(fun v -> + let t = + Transform.merge_into_generic_mappers ~embed_errors ~hook ~tool_name + ~expect_mismatch_handler ~input_name v + in + `Generic t) else - (let get_enclosers ~f = - List.filter_map transforms ~f:(fun (ct : Transform.t) -> - match f ct with None -> None | Some x -> Some (ct.name, x)) - (* Sort them to ensure deterministic ordering *) - |> List.sort ~cmp:(fun (a, _) (b, _) -> String.compare a b) - |> List.map ~f:snd - in - - let rules = - List.map transforms ~f:(fun (ct : Transform.t) -> ct.rules) - |> List.concat - and impl_enclosers = get_enclosers ~f:(fun ct -> ct.enclose_impl) - and intf_enclosers = get_enclosers ~f:(fun ct -> ct.enclose_intf) in - match (rules, impl_enclosers, intf_enclosers) with - | [], [], [] -> transforms - | _ -> - let merge_encloser = function - | [] -> None - | enclosers -> - Some - (fun ctxt loc -> - let headers, footers = - List.map enclosers ~f:(fun f -> f ctxt loc) |> List.split - in - let headers = List.concat headers in - let footers = List.concat (List.rev footers) in - (headers, footers)) - in - Transform.builtin_of_context_free_rewriters ~rules ~embed_errors - ~hook ~expect_mismatch_handler - ~enclose_impl:(merge_encloser impl_enclosers) - ~enclose_intf:(merge_encloser intf_enclosers) - ~tool_name ~input_name - :: transforms) - |> List.filter ~f:(fun (ct : Transform.t) -> - match (ct.impl, ct.intf) with None, None -> false | _ -> true) + (* We merge all context-free rewriters, this also includes enclosers. *) + let ctx_free_pass, transforms = + let get_enclosers ~f = + List.filter_map transforms ~f:(fun (ct : Transform.t) -> + match f ct with None -> None | Some x -> Some (ct.meta.name, x)) + (* Sort them to ensure deterministic ordering *) + |> List.sort ~cmp:(fun (a, _) (b, _) -> String.compare a b) + |> List.map ~f:snd + in + + let rules = + List.map transforms ~f:(fun (ct : Transform.t) -> ct.rules) + |> List.concat + and impl_enclosers = get_enclosers ~f:(fun ct -> ct.enclose_impl) + and intf_enclosers = get_enclosers ~f:(fun ct -> ct.enclose_intf) in + match (rules, impl_enclosers, intf_enclosers) with + | [], [], [] -> (None, transforms) + | _ -> + let merge_encloser = function + | [] -> None + | enclosers -> + Some + (fun ctxt loc -> + let headers, footers = + List.map enclosers ~f:(fun f -> f ctxt loc) + |> List.split + in + let headers = List.concat headers in + let footers = List.concat (List.rev footers) in + (headers, footers)) + in + let pass = + Transform.builtin_of_context_free_rewriters ~rules ~embed_errors + ~hook ~expect_mismatch_handler + ~enclose_impl:(merge_encloser impl_enclosers) + ~enclose_intf:(merge_encloser intf_enclosers) + ~tool_name ~input_name + in + (Some (`Generic pass), transforms) + in + let generic_transforms = + List.filter_map + ~f:(fun (ct : Transform.t) -> + match (ct.impl, ct.intf) with + | None, None -> None + | Some impl, None -> + Some (`Generic { Transform.meta = ct.meta; pass = `Impl impl }) + | None, Some intf -> + Some (`Generic { meta = ct.meta; pass = `Intf intf }) + | Some impl, Some intf -> + Some (`Generic { meta = ct.meta; pass = `Both (impl, intf) })) + transforms + in + Option.to_list ctx_free_pass @ generic_transforms in - linters @ preprocess @ before_instrs @ make_generic cts @ after_instrs + let generics = make_generic cts in + linters @ preprocess @ before_instrs @ generics @ after_instrs let apply_transforms ~tool_name ~file_path ~field ~lint_field ~dropped_so_far ~hook ~expect_mismatch_handler ~input_name ~embed_errors ast = @@ -542,7 +594,9 @@ let apply_transforms ~tool_name ~file_path ~field ~lint_field ~dropped_so_far let acc = List.fold_left cts ~init:(ast, [], [], []) ~f:(fun - (ast, dropped, (lint_errors : _ list), errors) (ct : Transform.t) -> + (ast, dropped, (lint_errors : _ list), errors) + (ct : Transform.transform) + -> let input_name = match input_name with | Some input_name -> input_name @@ -571,8 +625,8 @@ let apply_transforms ~tool_name ~file_path ~field ~lint_field ~dropped_so_far let dropped = if !debug_attribute_drop then ( let new_dropped = dropped_so_far ast in - debug_dropped_attribute ct.name ~old_dropped:dropped - ~new_dropped; + let name = (Transform.get_transform_meta ct).name in + debug_dropped_attribute name ~old_dropped:dropped ~new_dropped; new_dropped) else [] in @@ -619,7 +673,8 @@ let print_passes () = in if !perform_checks then Printf.printf "\n"; - List.iter cts ~f:(fun ct -> Printf.printf "%s\n" ct.Transform.name); + List.iter cts ~f:(fun ct -> + Printf.printf "%s\n" (Transform.get_transform_meta ct).name); if !perform_checks then ( Printf.printf "\n"; if !perform_checks_on_extensions then @@ -690,8 +745,10 @@ let map_structure_gen st ~tool_name ~hook ~expect_mismatch_handler ~input_name let file_path = get_default_path_str st in let st, lint_errors, errors = apply_transforms st ~tool_name ~file_path - ~field:(fun (ct : Transform.t) -> ct.impl) - ~lint_field:(fun (ct : Transform.t) -> ct.lint_impl) + ~field:(function + | `Generic pass -> Transform.get_pass_impl pass | _ -> None) + ~lint_field:(function + | `Linter pass -> Transform.get_pass_impl pass | _ -> None) ~dropped_so_far:Attribute.dropped_so_far_structure ~hook ~expect_mismatch_handler ~input_name ~embed_errors in @@ -766,8 +823,10 @@ let map_signature_gen sg ~tool_name ~hook ~expect_mismatch_handler ~input_name let file_path = get_default_path_sig sg in let sg, lint_errors, errors = apply_transforms sg ~tool_name ~file_path - ~field:(fun (ct : Transform.t) -> ct.intf) - ~lint_field:(fun (ct : Transform.t) -> ct.lint_intf) + ~field:(function + | `Generic pass -> Transform.get_pass_intf pass | _ -> None) + ~lint_field:(function + | `Linter pass -> Transform.get_pass_intf pass | _ -> None) ~dropped_so_far:Attribute.dropped_so_far_signature ~hook ~expect_mismatch_handler ~input_name ~embed_errors in @@ -1209,7 +1268,7 @@ let set_output_mode mode = let print_transformations () = List.iter !Transform.all ~f:(fun (ct : Transform.t) -> - Printf.printf "%s\n" ct.name) + Printf.printf "%s\n" ct.meta.name) let parse_apply_list s = let names = @@ -1260,7 +1319,7 @@ let interpret_mask () = | Some names -> is_candidate && not (List.exists names ~f:(Transform.has_name ct)) in - if is_selected then Some ct.name else None + if is_selected then Some ct.meta.name else None in apply_list := Some (List.filter_map !Transform.all ~f:selected_transform_name)