Skip to content

Commit

Permalink
Push side effects to rhs of let bindings.
Browse files Browse the repository at this point in the history
  • Loading branch information
maximebuyse committed Jan 13, 2025
1 parent 1258db5 commit c61f2d3
Show file tree
Hide file tree
Showing 6 changed files with 286 additions and 161 deletions.
60 changes: 59 additions & 1 deletion engine/lib/side_effect_utils.ml
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,61 @@ struct
collect_and_hoist_effects_object#visit_expr CollectContext.empty e
in
(lets_of_bindings lbs e, effects)

let has_side_effect =
object
inherit [_] Visitors.reduce as super
method zero = false
method plus l r = l || r

method! visit_expr' () =
function
| Assign _ | Return _ | Break _ | Continue _ -> true
| e -> super#visit_expr' () e
end

(* This visitor binds in `let_ = e in ()` all expressions
of type unit that are not already in a let binding.
This ensures that all side effects happen in the rhs of a let binding. *)
let bind_unit_return_position =
object (self)
inherit [_] Visitors.map as super

method! visit_expr in_let e =
match e.e with
| Let { monadic; rhs; lhs; body } ->
{
e with
e =
Let
{
monadic;
rhs = self#visit_expr true rhs;
lhs = self#visit_pat false lhs;
body = self#visit_expr false body;
};
}
| _ ->
let span = e.span in
if [%eq: expr'] e.e (U.unit_expr span).e then e
else if
[%eq: ty] e.typ U.unit_typ
&& (not in_let)
&& has_side_effect#visit_expr () e
then
{
e with
e =
Let
{
monadic = None;
rhs = self#visit_expr true e;
lhs = U.M.pat_PWild ~span ~typ:e.typ;
body = U.unit_expr span;
};
}
else super#visit_expr false e
end
end
end

Expand Down Expand Up @@ -538,7 +593,10 @@ struct
open ID

let dexpr (expr : A.expr) : B.expr =
Hoist.collect_and_hoist_effects expr |> fst |> dexpr
Hoist.collect_and_hoist_effects expr
|> fst
|> Hoist.bind_unit_return_position#visit_expr false
|> dexpr

[%%inline_defs "Item.*"]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,8 @@ let foo (v_LEN: usize) (arr: t_Array usize v_LEN) : usize =
(fun acc i ->
let acc:usize = acc in
let i:usize = i in
acc +! (arr.[ i ] <: usize) <: usize)
let acc:usize = acc +! (arr.[ i ] <: usize) in
acc)
in
acc

Expand Down
Loading

0 comments on commit c61f2d3

Please sign in to comment.