diff --git a/src/sail_lean_backend/pretty_print_lean.ml b/src/sail_lean_backend/pretty_print_lean.ml index 2d3dd7312..21275d883 100644 --- a/src/sail_lean_backend/pretty_print_lean.ml +++ b/src/sail_lean_backend/pretty_print_lean.ml @@ -21,13 +21,7 @@ type context = { kid_id_renames_rev : kid Bindings.t; (** Inverse of the [kid_id_renames] mapping. *) } -let initial_context env = - { - global = { effect_info = Effects.empty_side_effect_info }; - env; - kid_id_renames = KBindings.empty; - kid_id_renames_rev = Bindings.empty; - } +let initial_context env global = { global; env; kid_id_renames = KBindings.empty; kid_id_renames_rev = Bindings.empty } let add_single_kid_id_rename ctx id kid = let kir = @@ -379,7 +373,7 @@ and doc_exp (as_monadic : bool) ctx (E_aux (e, (l, annot)) as full_exp) = begin match pat with | P_aux (P_typ (_, P_aux (P_wild, _)), _) -> string "" - | _ -> flow (break 1) [string "let"; e0; string "←"] ^^ space + | _ -> flow (break 1) [string "let"; e0; string ":="] ^^ space end in nest 2 (e0_pp ^^ e1_pp) ^^ hardline ^^ e2_pp @@ -390,12 +384,15 @@ and doc_exp (as_monadic : bool) ctx (E_aux (e, (l, annot)) as full_exp) = in let d_args = List.map d_of_arg args in let fn_monadic = not (Effects.function_is_pure f ctx.global.effect_info) in - nest 2 (wrap_with_pure (as_monadic && fn_monadic) (parens (flow (break 1) (d_id :: d_args)))) + nest 2 + (wrap_with_left_arrow ((not as_monadic) && fn_monadic) + (wrap_with_pure (as_monadic && not fn_monadic) (parens (flow (break 1) (d_id :: d_args)))) + ) | E_vector vals -> string "#v" ^^ wrap_with_pure as_monadic (brackets (nest 2 (flow (comma ^^ break 1) (List.map d_of_arg vals)))) | E_typ (typ, e) -> if effectful (effect_of e) then - parens (separate space [doc_exp false ctx e; colon; string "SailM"; doc_typ ctx typ]) + parens (separate space [doc_exp as_monadic ctx e; colon; string "SailM"; doc_typ ctx typ]) else wrap_with_pure as_monadic (parens (separate space [doc_exp false ctx e; colon; doc_typ ctx typ])) | E_tuple es -> wrap_with_pure as_monadic (parens (separate_map (comma ^^ space) d_of_arg es)) | E_let (LB_aux (LB_val (lpat, lexp), _), e) -> @@ -455,7 +452,7 @@ let doc_binder ctx i t = let ctx = match captured_typ_var (i, t) with Some (i, ki) -> add_single_kid_id_rename ctx i ki | _ -> ctx in (ctx, separate space [string (string_of_id i); colon; doc_typ ctx t] |> paranthesizer) -let doc_funcl_init (FCL_aux (FCL_funcl (id, pexp), annot)) = +let doc_funcl_init global (FCL_aux (FCL_funcl (id, pexp), annot)) = let env = env_of_tannot (snd annot) in let TypQ_aux (tq, l), typ = Env.get_val_spec_orig id env in let arg_typs, ret_typ, _ = @@ -474,7 +471,7 @@ let doc_funcl_init (FCL_aux (FCL_funcl (id, pexp), annot)) = | _ -> failwith "Argument pattern not translatable yet." ) in - let ctx = initial_context env in + let ctx = initial_context env global in let ctx, binders = List.fold_left (fun (ctx, bs) (i, t) -> @@ -503,7 +500,7 @@ let doc_funcl_init (FCL_aux (FCL_funcl (id, pexp), annot)) = fixup_binders ) -let doc_funcl_body fixup_binders (FCL_aux (FCL_funcl (id, pexp), annot)) = +let doc_funcl_body fixup_binders global (FCL_aux (FCL_funcl (id, pexp), annot)) = let env = env_of_tannot (snd annot) in let ctx = initial_context env in let _, _, exp, _ = destruct_pexp pexp in @@ -511,11 +508,11 @@ let doc_funcl_body fixup_binders (FCL_aux (FCL_funcl (id, pexp), annot)) = this adds a let binding at the beginning of the function, of the form [let x := (arg0, arg1)] *) let exp = fixup_binders exp in let is_monadic = effectful (effect_of exp) in - doc_exp is_monadic (initial_context env) exp + doc_exp is_monadic (initial_context env global) exp let doc_funcl ctx funcl = - let comment, signature, env, fixup_binders = doc_funcl_init funcl in - comment ^^ nest 2 (signature ^^ hardline ^^ doc_funcl_body fixup_binders funcl) + let comment, signature, env, fixup_binders = doc_funcl_init ctx.global funcl in + comment ^^ nest 2 (signature ^^ hardline ^^ doc_funcl_body fixup_binders ctx.global funcl) let doc_fundef ctx (FD_aux (FD_function (r, typa, fcls), fannot)) = match fcls with @@ -642,8 +639,8 @@ let inhabit_enum ctx typ_map = ) typ_map -let doc_reg_info env registers = - let ctx = initial_context env in +let doc_reg_info env global registers = + let ctx = initial_context env global in let type_map = List.fold_left add_reg_typ Bindings.empty registers in let type_map = Bindings.bindings type_map in @@ -660,10 +657,11 @@ let doc_reg_info env registers = empty; ] -let pp_ast_lean (env : Type_check.env) ({ defs; _ } as ast : Libsail.Type_check.typed_ast) o = +let pp_ast_lean (env : Type_check.env) effect_info ({ defs; _ } as ast : Libsail.Type_check.typed_ast) o = let defs = remove_imports defs 0 in let regs = State.find_registers defs in - let register_refs = match regs with [] -> empty | _ -> doc_reg_info env regs in - let types, fundefs = doc_defs (initial_context env) defs in + let global = { effect_info } in + let register_refs = match regs with [] -> empty | _ -> doc_reg_info env global regs in + let types, fundefs = doc_defs (initial_context env global) defs in print o (types ^^ register_refs ^^ fundefs); () diff --git a/src/sail_lean_backend/sail_plugin_lean.ml b/src/sail_lean_backend/sail_plugin_lean.ml index 9217b9706..32666d12b 100644 --- a/src/sail_lean_backend/sail_plugin_lean.ml +++ b/src/sail_lean_backend/sail_plugin_lean.ml @@ -190,15 +190,15 @@ let create_lake_project (out_name : string) default_sail_dir = output_string project_main "open Sail\n\n"; project_main -let output (out_name : string) env ast default_sail_dir = +let output (out_name : string) env effect_info ast default_sail_dir = let project_main = create_lake_project out_name default_sail_dir in (* Uncomment for debug output of the Sail code after the rewrite passes *) (* Pretty_print_sail.output_ast stdout (Type_check.strip_ast ast); *) - Pretty_print_lean.pp_ast_lean env ast project_main; + Pretty_print_lean.pp_ast_lean env effect_info ast project_main; close_out project_main let lean_target out_name { default_sail_dir; ctx; ast; effect_info; env; _ } = let out_name = match out_name with Some f -> f | None -> "out" in - output out_name env ast default_sail_dir + output out_name env effect_info ast default_sail_dir let _ = Target.register ~name:"lean" ~options:lean_options ~rewrites:lean_rewrites ~asserts_termination:true lean_target diff --git a/test/lean/bitfield.expected.lean b/test/lean/bitfield.expected.lean index 40b741b78..ba2722f7b 100644 --- a/test/lean/bitfield.expected.lean +++ b/test/lean/bitfield.expected.lean @@ -31,7 +31,7 @@ def _update_cr_type_bits (v : (BitVec 8)) (x : (BitVec 8)) : (BitVec 8) := (Sail.BitVec.updateSubrange v (HSub.hSub 8 1) 0 x) def _set_cr_type_bits (r_ref : RegisterRef RegisterType (BitVec 8)) (v : (BitVec 8)) : SailM Unit := do - let r ← (reg_deref r_ref) + let r := (← (reg_deref r_ref)) writeRegRef r_ref (_update_cr_type_bits r v) def _get_cr_type_CR0 (v : (BitVec 8)) : (BitVec 4) := @@ -41,7 +41,7 @@ def _update_cr_type_CR0 (v : (BitVec 8)) (x : (BitVec 4)) : (BitVec 8) := (Sail.BitVec.updateSubrange v 7 4 x) def _set_cr_type_CR0 (r_ref : RegisterRef RegisterType (BitVec 8)) (v : (BitVec 4)) : SailM Unit := do - let r ← (reg_deref r_ref) + let r := (← (reg_deref r_ref)) writeRegRef r_ref (_update_cr_type_CR0 r v) def _get_cr_type_CR1 (v : (BitVec 8)) : (BitVec 2) := @@ -51,7 +51,7 @@ def _update_cr_type_CR1 (v : (BitVec 8)) (x : (BitVec 2)) : (BitVec 8) := (Sail.BitVec.updateSubrange v 3 2 x) def _set_cr_type_CR1 (r_ref : RegisterRef RegisterType (BitVec 8)) (v : (BitVec 2)) : SailM Unit := do - let r ← (reg_deref r_ref) + let r := (← (reg_deref r_ref)) writeRegRef r_ref (_update_cr_type_CR1 r v) def _get_cr_type_CR3 (v : (BitVec 8)) : (BitVec 2) := @@ -61,7 +61,7 @@ def _update_cr_type_CR3 (v : (BitVec 8)) (x : (BitVec 2)) : (BitVec 8) := (Sail.BitVec.updateSubrange v 1 0 x) def _set_cr_type_CR3 (r_ref : RegisterRef RegisterType (BitVec 8)) (v : (BitVec 2)) : SailM Unit := do - let r ← (reg_deref r_ref) + let r := (← (reg_deref r_ref)) writeRegRef r_ref (_update_cr_type_CR3 r v) def _get_cr_type_GT (v : (BitVec 8)) : (BitVec 1) := @@ -71,7 +71,7 @@ def _update_cr_type_GT (v : (BitVec 8)) (x : (BitVec 1)) : (BitVec 8) := (Sail.BitVec.updateSubrange v 6 6 x) def _set_cr_type_GT (r_ref : RegisterRef RegisterType (BitVec 8)) (v : (BitVec 1)) : SailM Unit := do - let r ← (reg_deref r_ref) + let r := (← (reg_deref r_ref)) writeRegRef r_ref (_update_cr_type_GT r v) def _get_cr_type_LT (v : (BitVec 8)) : (BitVec 1) := @@ -81,9 +81,9 @@ def _update_cr_type_LT (v : (BitVec 8)) (x : (BitVec 1)) : (BitVec 8) := (Sail.BitVec.updateSubrange v 7 7 x) def _set_cr_type_LT (r_ref : RegisterRef RegisterType (BitVec 8)) (v : (BitVec 1)) : SailM Unit := do - let r ← (reg_deref r_ref) + let r := (← (reg_deref r_ref)) writeRegRef r_ref (_update_cr_type_LT r v) def initialize_registers : SailM Unit := do - writeReg R (undefined_cr_type ()) + writeReg R (← (undefined_cr_type ())) diff --git a/test/lean/struct.expected.lean b/test/lean/struct.expected.lean index f4703e14c..48168af0b 100644 --- a/test/lean/struct.expected.lean +++ b/test/lean/struct.expected.lean @@ -6,6 +6,20 @@ structure My_struct where field1 : Int field2 : (BitVec 1) +inductive Register : Type where + | r + deriving DecidableEq, Hashable +open Register + +abbrev RegisterType : Register → Type + | .r => My_struct + +abbrev SailM := PreSailM RegisterType + +open RegisterRef +instance : Inhabited (RegisterRef RegisterType My_struct) where + default := .Reg r + def undefined_My_struct (lit : Unit) : SailM My_struct := do (pure { field1 := (← sorry) field2 := (← sorry) }) @@ -28,6 +42,6 @@ def mk_struct (i : Int) (b : (BitVec 1)) : My_struct := def undef_struct (x : (BitVec 1)) : SailM My_struct := do ((undefined_My_struct ()) : SailM My_struct) -def initialize_registers : Unit := - () +def initialize_registers : SailM Unit := do + writeReg r (← (undefined_My_struct ())) diff --git a/test/lean/struct.sail b/test/lean/struct.sail index d99b68aa6..25df1168a 100644 --- a/test/lean/struct.sail +++ b/test/lean/struct.sail @@ -7,6 +7,8 @@ struct My_struct = { field2 : bit, } +register r : My_struct + val struct_field2 : My_struct -> bit function struct_field2(s) = { s.field2