Skip to content

Commit

Permalink
Create HLO token in the computation only when needed (#1494)
Browse files Browse the repository at this point in the history
  • Loading branch information
jonatanklosko authored May 23, 2024
1 parent 773f4d6 commit c7cdd84
Showing 1 changed file with 97 additions and 35 deletions.
132 changes: 97 additions & 35 deletions exla/lib/exla/defn.ex
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ defmodule EXLA.Defn do
compile_options,
used_buffers,
used_inputs,
_stream = true,
comp_fun
)

Expand Down Expand Up @@ -258,7 +259,7 @@ defmodule EXLA.Defn do
callback = &to_root_computation(&1, &2, &3, &4, Keyword.put(compile_options, :client, client))

{executable, used_inputs, outputs, outfeed, :ok, debug?} =
compile(client, key, vars, fun, compile_options, 0, [], callback)
compile(client, key, vars, fun, compile_options, 0, [], _stream = false, callback)

fn [args] ->
{time, lock} =
Expand Down Expand Up @@ -357,7 +358,17 @@ defmodule EXLA.Defn do

## Compile

defp compile(client, key, vars, fun, options, used_buffers, used_inputs, to_computation) do
defp compile(
client,
key,
vars,
fun,
options,
used_buffers,
used_inputs,
stream?,
to_computation
) do
{{expr_cache_fun, comp_cache_fun}, options} =
case Keyword.pop(options, :cache, true) do
{true, options} ->
Expand Down Expand Up @@ -385,7 +396,7 @@ defmodule EXLA.Defn do

{eval_time, {expr, {ref, outputs, {used_inputs, defined_hooks}}}} =
:timer.tc(fn ->
expr_cache_fun.({key, args_key}, fn ->
expr_cache_fun.({key, args_key, lazy_transfers}, fn ->
expr = fun.(vars)
inputs_and_hooks = Outfeed.used_inputs_and_hooks(expr, used_inputs, lazy_transfers)
{expr, {make_ref(), Nx.to_template(expr), inputs_and_hooks}}
Expand Down Expand Up @@ -432,10 +443,16 @@ defmodule EXLA.Defn do
end)

EXLA.MLIR.Module.new(comp_arg_typespecs, out_typespecs, fn builder ->
# Only create the token when we know it will actually be
# used, that is: streaming, lazy transfers or hooks
outfeed =
outfeed
|> Outfeed.with_token(Value.create_token(builder))
|> Outfeed.add_infeeds(builder, reverse_infeeds)
if stream? or reverse_infeeds != [] or hooks != %{} or defined_hooks != %{} do
outfeed
|> Outfeed.with_token(Value.create_token(builder))
|> Outfeed.add_infeeds(builder, reverse_infeeds)
else
outfeed
end

expr = Nx.Defn.Composite.traverse(expr || fun.(vars), &Nx.devectorize/1)

Expand Down Expand Up @@ -520,19 +537,30 @@ defmodule EXLA.Defn do
cache
) do
[initial_arg, _arg, pred, body] = args
initial_with_token = {get_token(cache), initial_arg}

{initial, cache} = recur_composite(initial_with_token, state, cache)
initial =
if token = get_token(cache) do
{token, initial_arg}
else
initial_arg
end

{initial, cache} = recur_composite(initial, state, cache)

{pred_computation, cache} = mlir_while_computation(pred, initial, {:pred, 8}, state, cache)
{body_computation, cache} = mlir_while_computation(body, initial, :with_token, state, cache)

[token | results] =
results =
Value.while(function, pred_computation, body_computation, List.flatten(initial))

result = wrap_tuple_result(results, initial_arg)

{result, update_token(cache, token)}
if get_token(cache) do
[token | results] = results
result = wrap_tuple_result(results, initial_arg)
{result, update_token(cache, token)}
else
result = wrap_tuple_result(results, initial_arg)
{result, cache}
end
end

defp cached_recur_operator(:cond, %T{data: %Expr{args: args}} = t, state, cache) do
Expand Down Expand Up @@ -688,16 +716,19 @@ defmodule EXLA.Defn do
{computation, cache}

%{} ->
{computation, cache} = token_computation("optional", call_args, expr, state, cache)
{computation, cache} = optional_computation("optional", call_args, expr, state, cache)
{computation, Map.put(cache, key, computation)}
end

typespecs = [Typespec.token() | container_to_typespecs(expr)]

[token | result] =
Value.call(state.builder, [get_token(cache) | call_args], call_body, typespecs)

{wrap_tuple_result(result, expr), update_token(cache, token)}
if token = get_token(cache) do
typespecs = [Typespec.token() | container_to_typespecs(expr)]
[token | result] = Value.call(state.builder, [token | call_args], call_body, typespecs)
{wrap_tuple_result(result, expr), update_token(cache, token)}
else
typespecs = container_to_typespecs(expr)
result = Value.call(state.builder, call_args, call_body, typespecs)
{wrap_tuple_result(result, expr), cache}
end
end

defp cached_recur_operator(:attach_token, %T{data: %Expr{args: [token, expr]}}, state, cache) do
Expand Down Expand Up @@ -1553,7 +1584,17 @@ defmodule EXLA.Defn do
defp mlir_while_computation(expr, initial, type, state, cache) do
arg_typespecs = Enum.map(List.flatten(initial), &Value.get_typespec/1)

{region, [arg_token | arg_params]} = Function.push_region(state.builder, arg_typespecs)
{region, args} = Function.push_region(state.builder, arg_typespecs)

outer_token = get_token(cache)

{inner_token, arg_params} =
if outer_token do
[arg_token | arg_params] = args
{arg_token, arg_params}
else
{nil, args}
end

params = Enum.with_index(arg_params, &{&2, &1})

Expand All @@ -1570,11 +1611,15 @@ defmodule EXLA.Defn do
expr
end

{res, comp_cache} = recur_composite(expr, & &1, state, reset_token(cache, arg_token))
{res, comp_cache} = recur_composite(expr, & &1, state, reset_token(cache, inner_token))

res =
if type == :with_token do
[get_token(comp_cache) | List.flatten(res)]
if outer_token do
[get_token(comp_cache) | List.flatten(res)]
else
List.flatten(res)
end
else
Enum.map(res, &to_type(&1, type))
end
Expand All @@ -1585,21 +1630,34 @@ defmodule EXLA.Defn do
{region, merge_outfeed(cache, comp_cache)}
end

defp token_computation(name, args, expr, %{builder: %Function{}} = state, cache) do
defp optional_computation(name, args, expr, %{builder: %Function{}} = state, cache) do
%Function{module: module, name: name} = subbuilder(state.builder, name)

token_typespec = Typespec.token()
arg_typespecs = Enum.map(args, &Value.get_typespec/1)
out_typespecs = container_to_typespecs(expr)

function =
EXLA.MLIR.Module.add_function(module, name, [token_typespec | arg_typespecs], [
token_typespec | out_typespecs
])
outer_token = get_token(cache)
token_typespec = Typespec.token()

{arg_typespecs, out_typespecs} =
if outer_token do
{[token_typespec | arg_typespecs], [token_typespec | out_typespecs]}
else
{arg_typespecs, out_typespecs}
end

[arg_token | tail] = EXLA.MLIR.Function.get_arguments(function)
function = EXLA.MLIR.Module.add_function(module, name, arg_typespecs, out_typespecs)
args = EXLA.MLIR.Function.get_arguments(function)

params = Enum.with_index(tail, fn param, i -> {i, param} end)
{inner_token, args} =
if outer_token do
[arg_token | args] = args
{arg_token, args}
else
{nil, args}
end

params = Enum.with_index(args, fn param, i -> {i, param} end)

state = %{
state
Expand All @@ -1608,9 +1666,13 @@ defmodule EXLA.Defn do
scope_ids: Tree.scope_ids(expr)
}

{res, comp_cache} = recur_composite(expr, state, reset_token(cache, arg_token))
{res, comp_cache} = recur_composite(expr, state, reset_token(cache, inner_token))

Value.return(function, [get_token(comp_cache) | List.flatten(res)])
if outer_token do
Value.return(function, [get_token(comp_cache) | List.flatten(res)])
else
Value.return(function, List.flatten(res))
end

{function, merge_outfeed(cache, comp_cache)}
end
Expand Down Expand Up @@ -1786,10 +1848,10 @@ defmodule EXLA.Defn do

out_typespecs = container_to_typespecs(on_true)

in_token = get_token(cache)
outer_token = get_token(cache)

result_typespecs =
if in_token do
if outer_token do
[Typespec.token() | out_typespecs]
else
out_typespecs
Expand All @@ -1799,7 +1861,7 @@ defmodule EXLA.Defn do
{false_computation, cache} = to_mlir_if_branch(on_false, false_ids, state, cache)
if_results = Value.if_op(pred_op, true_computation, false_computation, result_typespecs)

if in_token do
if outer_token do
[token | results] = if_results
{wrap_tuple_result(results, on_true), update_token(cache, token)}
else
Expand Down

0 comments on commit c7cdd84

Please sign in to comment.