diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex index 494512674d..56e6ac5c37 100644 --- a/exla/lib/exla/defn.ex +++ b/exla/lib/exla/defn.ex @@ -61,6 +61,7 @@ defmodule EXLA.Defn do compile_options, used_buffers, used_inputs, + _stream = true, comp_fun ) @@ -255,7 +256,7 @@ defmodule EXLA.Defn do def __compile__(key, vars, fun, options) do {run_options, compile_options} = Keyword.pop(options, :run_options, []) - {executable, used_inputs, outputs, outfeed, :ok, debug?} = + {:ok, {executable, {used_inputs, outputs, outfeed, debug?}}} = compile_executable(key, vars, fun, compile_options) fn [args] -> @@ -282,9 +283,8 @@ defmodule EXLA.Defn do end def export_executable(fun, vars, options) do - fun - |> compile_executable(vars, &Function.identity/1, Keyword.delete(options, :run_options)) - |> elem(0) + {:ok, {executable, _}} = + compile_executable(fun, vars, & &1, Keyword.delete(options, :run_options)) end defp compile_executable(key, vars, fun, compile_options) do @@ -298,7 +298,9 @@ defmodule EXLA.Defn do callback = &to_root_computation(&1, &2, &3, &4, Keyword.put(compile_options, :client, client)) {executable, used_inputs, outputs, outfeed, :ok, debug?} = - compile(client, key, vars, fun, compile_options, 0, [], callback) + compile(client, key, vars, fun, compile_options, 0, [], _stream = false, callback) + + {:ok, {executable, {used_inputs, outputs, outfeed, debug?}}} end defp to_root_computation(%Function{} = function, expr, used_typespecs, outfeed, options) do @@ -396,7 +398,17 @@ defmodule EXLA.Defn do ## Compile - defp compile(client, key, vars, fun, options, used_buffers, used_inputs, to_computation) do + defp compile( + client, + key, + vars, + fun, + options, + used_buffers, + used_inputs, + stream?, + to_computation + ) do {{expr_cache_fun, comp_cache_fun}, options} = case Keyword.pop(options, :cache, true) do {true, options} -> @@ -425,7 +437,7 @@ defmodule EXLA.Defn do {eval_time, {expr, {ref, outputs, {used_inputs, defined_hooks}}}} = :timer.tc(fn -> - expr_cache_fun.({key, args_key}, fn -> + expr_cache_fun.({key, args_key, lazy_transfers}, fn -> expr = fun.(vars) inputs_and_hooks = Outfeed.used_inputs_and_hooks(expr, used_inputs, lazy_transfers) {expr, {make_ref(), Nx.to_template(expr), inputs_and_hooks}} @@ -473,13 +485,15 @@ defmodule EXLA.Defn do end) EXLA.MLIR.Module.new(comp_arg_typespecs, out_typespecs, fn builder -> - builder = %EXLA.MLIR.Function{builder | runtime: runtime} - + # Only create the token when we know it will actually be + # used, that is: streaming, lazy transfers or hooks outfeed = - if runtime != :iree do + if stream? or reverse_infeeds != [] or hooks != %{} or defined_hooks != %{} do outfeed |> Outfeed.with_token(Value.create_token(builder)) |> Outfeed.add_infeeds(builder, reverse_infeeds) + else + outfeed end expr = Nx.Defn.Composite.traverse(expr || fun.(vars), &Nx.devectorize/1) @@ -598,29 +612,28 @@ defmodule EXLA.Defn do ) do [initial_arg, _arg, pred, body] = args - initial_with_token = - if state.builder.runtime == :iree do - initial_arg + initial = + if token = get_token(cache) do + {token, initial_arg} else - [get_token(cache), initial_arg] + initial_arg end - {initial, cache} = recur_composite(initial_with_token, state, cache) + {initial, cache} = recur_composite(initial, state, cache) {pred_computation, cache} = mlir_while_computation(pred, initial, {:pred, 8}, state, cache) {body_computation, cache} = mlir_while_computation(body, initial, :with_token, state, cache) - output = Value.while(function, pred_computation, body_computation, List.flatten(initial)) - - case state.builder.runtime do - :iree -> - result = wrap_tuple_result(output, initial_arg) - {result, cache} + results = + Value.while(function, pred_computation, body_computation, List.flatten(initial)) - _ -> - [token | results] = output - result = wrap_tuple_result(results, initial_arg) - {result, update_token(cache, token)} + if get_token(cache) do + [token | results] = results + result = wrap_tuple_result(results, initial_arg) + {result, update_token(cache, token)} + else + result = wrap_tuple_result(results, initial_arg) + {result, cache} end end @@ -777,22 +790,18 @@ defmodule EXLA.Defn do {computation, cache} %{} -> - {computation, cache} = token_computation("optional", call_args, expr, state, cache) + {computation, cache} = optional_computation("optional", call_args, expr, state, cache) {computation, Map.put(cache, key, computation)} end - if state.builder.runtime == :iree do + if token = get_token(cache) do + typespecs = [Typespec.token() | container_to_typespecs(expr)] + [token | result] = Value.call(state.builder, [token | call_args], call_body, typespecs) + {wrap_tuple_result(result, expr), update_token(cache, token)} + else typespecs = container_to_typespecs(expr) - result = Value.call(state.builder, call_args, call_body, typespecs) {wrap_tuple_result(result, expr), cache} - else - typespecs = [Typespec.token() | container_to_typespecs(expr)] - - [token | result] = - Value.call(state.builder, [get_token(cache) | call_args], call_body, typespecs) - - {wrap_tuple_result(result, expr), update_token(cache, token)} end end @@ -1658,7 +1667,17 @@ defmodule EXLA.Defn do defp mlir_while_computation(expr, initial, type, state, cache) do arg_typespecs = Enum.map(List.flatten(initial), &Value.get_typespec/1) - {region, [arg_token | arg_params]} = Function.push_region(state.builder, arg_typespecs) + {region, args} = Function.push_region(state.builder, arg_typespecs) + + outer_token = get_token(cache) + + {inner_token, arg_params} = + if outer_token do + [arg_token | arg_params] = args + {arg_token, arg_params} + else + {nil, args} + end params = Enum.with_index(arg_params, &{&2, &1}) @@ -1675,11 +1694,15 @@ defmodule EXLA.Defn do expr end - {res, comp_cache} = recur_composite(expr, & &1, state, reset_token(cache, arg_token)) + {res, comp_cache} = recur_composite(expr, & &1, state, reset_token(cache, inner_token)) res = if type == :with_token do - [get_token(comp_cache) | List.flatten(res)] + if outer_token do + [get_token(comp_cache) | List.flatten(res)] + else + List.flatten(res) + end else Enum.map(res, &to_type(&1, type)) end @@ -1690,47 +1713,35 @@ defmodule EXLA.Defn do {region, merge_outfeed(cache, comp_cache)} end - defp token_computation( - name, - args, - expr, - %{builder: %Function{runtime: runtime}} = state, - cache - ) do + defp optional_computation(name, args, expr, %{builder: %Function{}} = state, cache) do %Function{module: module, name: name} = subbuilder(state.builder, name) - token_typespec = Typespec.token() arg_typespecs = Enum.map(args, &Value.get_typespec/1) out_typespecs = container_to_typespecs(expr) - in_types = - if runtime == :iree do - arg_typespecs - else - [token_typespec | arg_typespecs] - end + outer_token = get_token(cache) + token_typespec = Typespec.token() - out_types = - if runtime == :iree do - out_typespecs + {arg_typespecs, out_typespecs} = + if outer_token do + {[token_typespec | arg_typespecs], [token_typespec | out_typespecs]} else - [token_typespec | out_typespecs] + {arg_typespecs, out_typespecs} end - function = - EXLA.MLIR.Module.add_function(module, name, in_types, out_types) + function = EXLA.MLIR.Module.add_function(module, name, arg_typespecs, out_typespecs) + args = EXLA.MLIR.Function.get_arguments(function) - function = %{function | runtime: runtime} - - [arg_token | tail] = EXLA.MLIR.Function.get_arguments(function) - - params = - if runtime == :iree do - Enum.with_index([arg_token | tail], fn param, i -> {i, param} end) + {inner_token, args} = + if outer_token do + [arg_token | args] = args + {arg_token, args} else - Enum.with_index(tail, fn param, i -> {i, param} end) + {nil, args} end + params = Enum.with_index(args, fn param, i -> {i, param} end) + state = %{ state | builder: function, @@ -1738,15 +1749,15 @@ defmodule EXLA.Defn do scope_ids: Tree.scope_ids(expr) } - if runtime == :iree do - {res, comp_cache} = recur_composite(expr, state, cache) - Value.func_return(function, List.flatten(res)) - {function, merge_outfeed(cache, comp_cache)} - else - {res, comp_cache} = recur_composite(expr, state, reset_token(cache, arg_token)) + {res, comp_cache} = recur_composite(expr, state, reset_token(cache, inner_token)) + + if outer_token do Value.func_return(function, [get_token(comp_cache) | List.flatten(res)]) - {function, merge_outfeed(cache, comp_cache)} + else + Value.func_return(function, List.flatten(res)) end + + {function, merge_outfeed(cache, comp_cache)} end # The cache is built on top of call args because we need to handle pred/u8. @@ -1920,10 +1931,10 @@ defmodule EXLA.Defn do out_typespecs = container_to_typespecs(on_true) - in_token = get_token(cache) + outer_token = get_token(cache) result_typespecs = - if in_token do + if outer_token do [Typespec.token() | out_typespecs] else out_typespecs @@ -1933,7 +1944,7 @@ defmodule EXLA.Defn do {false_computation, cache} = to_mlir_if_branch(on_false, false_ids, state, cache) if_results = Value.if_op(pred_op, true_computation, false_computation, result_typespecs) - if in_token do + if outer_token do [token | results] = if_results {wrap_tuple_result(results, on_true), update_token(cache, token)} else diff --git a/exla/test/exla/defn/expr_test.exs b/exla/test/exla/defn/expr_test.exs index f9a62822a9..115c865a0f 100644 --- a/exla/test/exla/defn/expr_test.exs +++ b/exla/test/exla/defn/expr_test.exs @@ -1515,26 +1515,6 @@ defmodule EXLA.Defn.ExprTest do end end - defn while_inside_if(pred, x) do - if pred do - {x, _} = - while {x, i = 0}, i < 10 do - {x, i + 1} - end - - x - else - x - end - end - - test "while inside if" do - assert %{a: a, b: b} = while_inside_if(1, %{a: 1, b: 2.0}) - assert_all_close(a, 1) - assert_all_close(b, 2.0) - end - end - describe "reduce" do defn reduce(t), do: Nx.reduce(t, 1, fn a, b -> a * b end) defn reduce_keep(t), do: Nx.reduce(t, 1, [keep_axes: true], fn a, b -> a * b end)