From c7cdd847a0813897d6ddb7ba7b315a5ff8e73f1a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonatan=20K=C5=82osko?= Date: Thu, 23 May 2024 21:24:22 +0200 Subject: [PATCH] Create HLO token in the computation only when needed (#1494) --- exla/lib/exla/defn.ex | 132 +++++++++++++++++++++++++++++++----------- 1 file changed, 97 insertions(+), 35 deletions(-) diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex index 36170de1af..86cceb4b5d 100644 --- a/exla/lib/exla/defn.ex +++ b/exla/lib/exla/defn.ex @@ -60,6 +60,7 @@ defmodule EXLA.Defn do compile_options, used_buffers, used_inputs, + _stream = true, comp_fun ) @@ -258,7 +259,7 @@ defmodule EXLA.Defn do callback = &to_root_computation(&1, &2, &3, &4, Keyword.put(compile_options, :client, client)) {executable, used_inputs, outputs, outfeed, :ok, debug?} = - compile(client, key, vars, fun, compile_options, 0, [], callback) + compile(client, key, vars, fun, compile_options, 0, [], _stream = false, callback) fn [args] -> {time, lock} = @@ -357,7 +358,17 @@ defmodule EXLA.Defn do ## Compile - defp compile(client, key, vars, fun, options, used_buffers, used_inputs, to_computation) do + defp compile( + client, + key, + vars, + fun, + options, + used_buffers, + used_inputs, + stream?, + to_computation + ) do {{expr_cache_fun, comp_cache_fun}, options} = case Keyword.pop(options, :cache, true) do {true, options} -> @@ -385,7 +396,7 @@ defmodule EXLA.Defn do {eval_time, {expr, {ref, outputs, {used_inputs, defined_hooks}}}} = :timer.tc(fn -> - expr_cache_fun.({key, args_key}, fn -> + expr_cache_fun.({key, args_key, lazy_transfers}, fn -> expr = fun.(vars) inputs_and_hooks = Outfeed.used_inputs_and_hooks(expr, used_inputs, lazy_transfers) {expr, {make_ref(), Nx.to_template(expr), inputs_and_hooks}} @@ -432,10 +443,16 @@ defmodule EXLA.Defn do end) EXLA.MLIR.Module.new(comp_arg_typespecs, out_typespecs, fn builder -> + # Only create the token when we know it will actually be + # used, that is: streaming, lazy transfers or hooks outfeed = - outfeed - |> Outfeed.with_token(Value.create_token(builder)) - |> Outfeed.add_infeeds(builder, reverse_infeeds) + if stream? or reverse_infeeds != [] or hooks != %{} or defined_hooks != %{} do + outfeed + |> Outfeed.with_token(Value.create_token(builder)) + |> Outfeed.add_infeeds(builder, reverse_infeeds) + else + outfeed + end expr = Nx.Defn.Composite.traverse(expr || fun.(vars), &Nx.devectorize/1) @@ -520,19 +537,30 @@ defmodule EXLA.Defn do cache ) do [initial_arg, _arg, pred, body] = args - initial_with_token = {get_token(cache), initial_arg} - {initial, cache} = recur_composite(initial_with_token, state, cache) + initial = + if token = get_token(cache) do + {token, initial_arg} + else + initial_arg + end + + {initial, cache} = recur_composite(initial, state, cache) {pred_computation, cache} = mlir_while_computation(pred, initial, {:pred, 8}, state, cache) {body_computation, cache} = mlir_while_computation(body, initial, :with_token, state, cache) - [token | results] = + results = Value.while(function, pred_computation, body_computation, List.flatten(initial)) - result = wrap_tuple_result(results, initial_arg) - - {result, update_token(cache, token)} + if get_token(cache) do + [token | results] = results + result = wrap_tuple_result(results, initial_arg) + {result, update_token(cache, token)} + else + result = wrap_tuple_result(results, initial_arg) + {result, cache} + end end defp cached_recur_operator(:cond, %T{data: %Expr{args: args}} = t, state, cache) do @@ -688,16 +716,19 @@ defmodule EXLA.Defn do {computation, cache} %{} -> - {computation, cache} = token_computation("optional", call_args, expr, state, cache) + {computation, cache} = optional_computation("optional", call_args, expr, state, cache) {computation, Map.put(cache, key, computation)} end - typespecs = [Typespec.token() | container_to_typespecs(expr)] - - [token | result] = - Value.call(state.builder, [get_token(cache) | call_args], call_body, typespecs) - - {wrap_tuple_result(result, expr), update_token(cache, token)} + if token = get_token(cache) do + typespecs = [Typespec.token() | container_to_typespecs(expr)] + [token | result] = Value.call(state.builder, [token | call_args], call_body, typespecs) + {wrap_tuple_result(result, expr), update_token(cache, token)} + else + typespecs = container_to_typespecs(expr) + result = Value.call(state.builder, call_args, call_body, typespecs) + {wrap_tuple_result(result, expr), cache} + end end defp cached_recur_operator(:attach_token, %T{data: %Expr{args: [token, expr]}}, state, cache) do @@ -1553,7 +1584,17 @@ defmodule EXLA.Defn do defp mlir_while_computation(expr, initial, type, state, cache) do arg_typespecs = Enum.map(List.flatten(initial), &Value.get_typespec/1) - {region, [arg_token | arg_params]} = Function.push_region(state.builder, arg_typespecs) + {region, args} = Function.push_region(state.builder, arg_typespecs) + + outer_token = get_token(cache) + + {inner_token, arg_params} = + if outer_token do + [arg_token | arg_params] = args + {arg_token, arg_params} + else + {nil, args} + end params = Enum.with_index(arg_params, &{&2, &1}) @@ -1570,11 +1611,15 @@ defmodule EXLA.Defn do expr end - {res, comp_cache} = recur_composite(expr, & &1, state, reset_token(cache, arg_token)) + {res, comp_cache} = recur_composite(expr, & &1, state, reset_token(cache, inner_token)) res = if type == :with_token do - [get_token(comp_cache) | List.flatten(res)] + if outer_token do + [get_token(comp_cache) | List.flatten(res)] + else + List.flatten(res) + end else Enum.map(res, &to_type(&1, type)) end @@ -1585,21 +1630,34 @@ defmodule EXLA.Defn do {region, merge_outfeed(cache, comp_cache)} end - defp token_computation(name, args, expr, %{builder: %Function{}} = state, cache) do + defp optional_computation(name, args, expr, %{builder: %Function{}} = state, cache) do %Function{module: module, name: name} = subbuilder(state.builder, name) - token_typespec = Typespec.token() arg_typespecs = Enum.map(args, &Value.get_typespec/1) out_typespecs = container_to_typespecs(expr) - function = - EXLA.MLIR.Module.add_function(module, name, [token_typespec | arg_typespecs], [ - token_typespec | out_typespecs - ]) + outer_token = get_token(cache) + token_typespec = Typespec.token() + + {arg_typespecs, out_typespecs} = + if outer_token do + {[token_typespec | arg_typespecs], [token_typespec | out_typespecs]} + else + {arg_typespecs, out_typespecs} + end - [arg_token | tail] = EXLA.MLIR.Function.get_arguments(function) + function = EXLA.MLIR.Module.add_function(module, name, arg_typespecs, out_typespecs) + args = EXLA.MLIR.Function.get_arguments(function) - params = Enum.with_index(tail, fn param, i -> {i, param} end) + {inner_token, args} = + if outer_token do + [arg_token | args] = args + {arg_token, args} + else + {nil, args} + end + + params = Enum.with_index(args, fn param, i -> {i, param} end) state = %{ state @@ -1608,9 +1666,13 @@ defmodule EXLA.Defn do scope_ids: Tree.scope_ids(expr) } - {res, comp_cache} = recur_composite(expr, state, reset_token(cache, arg_token)) + {res, comp_cache} = recur_composite(expr, state, reset_token(cache, inner_token)) - Value.return(function, [get_token(comp_cache) | List.flatten(res)]) + if outer_token do + Value.return(function, [get_token(comp_cache) | List.flatten(res)]) + else + Value.return(function, List.flatten(res)) + end {function, merge_outfeed(cache, comp_cache)} end @@ -1786,10 +1848,10 @@ defmodule EXLA.Defn do out_typespecs = container_to_typespecs(on_true) - in_token = get_token(cache) + outer_token = get_token(cache) result_typespecs = - if in_token do + if outer_token do [Typespec.token() | out_typespecs] else out_typespecs @@ -1799,7 +1861,7 @@ defmodule EXLA.Defn do {false_computation, cache} = to_mlir_if_branch(on_false, false_ids, state, cache) if_results = Value.if_op(pred_op, true_computation, false_computation, result_typespecs) - if in_token do + if outer_token do [token | results] = if_results {wrap_tuple_result(results, on_true), update_token(cache, token)} else