Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into pv-feat/iree-compiler
Browse files Browse the repository at this point in the history
  • Loading branch information
polvalente committed May 23, 2024
2 parents 28b4bbd + c7cdd84 commit 2f4e879
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 96 deletions.
163 changes: 87 additions & 76 deletions exla/lib/exla/defn.ex
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ defmodule EXLA.Defn do
compile_options,
used_buffers,
used_inputs,
_stream = true,
comp_fun
)

Expand Down Expand Up @@ -255,7 +256,7 @@ defmodule EXLA.Defn do
def __compile__(key, vars, fun, options) do
{run_options, compile_options} = Keyword.pop(options, :run_options, [])

{executable, used_inputs, outputs, outfeed, :ok, debug?} =
{:ok, {executable, {used_inputs, outputs, outfeed, debug?}}} =
compile_executable(key, vars, fun, compile_options)

fn [args] ->
Expand All @@ -282,9 +283,8 @@ defmodule EXLA.Defn do
end

def export_executable(fun, vars, options) do
fun
|> compile_executable(vars, &Function.identity/1, Keyword.delete(options, :run_options))
|> elem(0)
{:ok, {executable, _}} =
compile_executable(fun, vars, & &1, Keyword.delete(options, :run_options))
end

defp compile_executable(key, vars, fun, compile_options) do
Expand All @@ -298,7 +298,9 @@ defmodule EXLA.Defn do
callback = &to_root_computation(&1, &2, &3, &4, Keyword.put(compile_options, :client, client))

{executable, used_inputs, outputs, outfeed, :ok, debug?} =
compile(client, key, vars, fun, compile_options, 0, [], callback)
compile(client, key, vars, fun, compile_options, 0, [], _stream = false, callback)

{:ok, {executable, {used_inputs, outputs, outfeed, debug?}}}
end

defp to_root_computation(%Function{} = function, expr, used_typespecs, outfeed, options) do
Expand Down Expand Up @@ -396,7 +398,17 @@ defmodule EXLA.Defn do

## Compile

defp compile(client, key, vars, fun, options, used_buffers, used_inputs, to_computation) do
defp compile(
client,
key,
vars,
fun,
options,
used_buffers,
used_inputs,
stream?,
to_computation
) do
{{expr_cache_fun, comp_cache_fun}, options} =
case Keyword.pop(options, :cache, true) do
{true, options} ->
Expand Down Expand Up @@ -425,7 +437,7 @@ defmodule EXLA.Defn do

{eval_time, {expr, {ref, outputs, {used_inputs, defined_hooks}}}} =
:timer.tc(fn ->
expr_cache_fun.({key, args_key}, fn ->
expr_cache_fun.({key, args_key, lazy_transfers}, fn ->
expr = fun.(vars)
inputs_and_hooks = Outfeed.used_inputs_and_hooks(expr, used_inputs, lazy_transfers)
{expr, {make_ref(), Nx.to_template(expr), inputs_and_hooks}}
Expand Down Expand Up @@ -473,13 +485,15 @@ defmodule EXLA.Defn do
end)

EXLA.MLIR.Module.new(comp_arg_typespecs, out_typespecs, fn builder ->
builder = %EXLA.MLIR.Function{builder | runtime: runtime}

# Only create the token when we know it will actually be
# used, that is: streaming, lazy transfers or hooks
outfeed =
if runtime != :iree do
if stream? or reverse_infeeds != [] or hooks != %{} or defined_hooks != %{} do
outfeed
|> Outfeed.with_token(Value.create_token(builder))
|> Outfeed.add_infeeds(builder, reverse_infeeds)
else
outfeed
end

expr = Nx.Defn.Composite.traverse(expr || fun.(vars), &Nx.devectorize/1)
Expand Down Expand Up @@ -598,29 +612,28 @@ defmodule EXLA.Defn do
) do
[initial_arg, _arg, pred, body] = args

initial_with_token =
if state.builder.runtime == :iree do
initial_arg
initial =
if token = get_token(cache) do
{token, initial_arg}
else
[get_token(cache), initial_arg]
initial_arg
end

{initial, cache} = recur_composite(initial_with_token, state, cache)
{initial, cache} = recur_composite(initial, state, cache)

{pred_computation, cache} = mlir_while_computation(pred, initial, {:pred, 8}, state, cache)
{body_computation, cache} = mlir_while_computation(body, initial, :with_token, state, cache)

output = Value.while(function, pred_computation, body_computation, List.flatten(initial))

case state.builder.runtime do
:iree ->
result = wrap_tuple_result(output, initial_arg)
{result, cache}
results =
Value.while(function, pred_computation, body_computation, List.flatten(initial))

_ ->
[token | results] = output
result = wrap_tuple_result(results, initial_arg)
{result, update_token(cache, token)}
if get_token(cache) do
[token | results] = results
result = wrap_tuple_result(results, initial_arg)
{result, update_token(cache, token)}
else
result = wrap_tuple_result(results, initial_arg)
{result, cache}
end
end

Expand Down Expand Up @@ -777,22 +790,18 @@ defmodule EXLA.Defn do
{computation, cache}

%{} ->
{computation, cache} = token_computation("optional", call_args, expr, state, cache)
{computation, cache} = optional_computation("optional", call_args, expr, state, cache)
{computation, Map.put(cache, key, computation)}
end

if state.builder.runtime == :iree do
if token = get_token(cache) do
typespecs = [Typespec.token() | container_to_typespecs(expr)]
[token | result] = Value.call(state.builder, [token | call_args], call_body, typespecs)
{wrap_tuple_result(result, expr), update_token(cache, token)}
else
typespecs = container_to_typespecs(expr)

result = Value.call(state.builder, call_args, call_body, typespecs)
{wrap_tuple_result(result, expr), cache}
else
typespecs = [Typespec.token() | container_to_typespecs(expr)]

[token | result] =
Value.call(state.builder, [get_token(cache) | call_args], call_body, typespecs)

{wrap_tuple_result(result, expr), update_token(cache, token)}
end
end

Expand Down Expand Up @@ -1658,7 +1667,17 @@ defmodule EXLA.Defn do
defp mlir_while_computation(expr, initial, type, state, cache) do
arg_typespecs = Enum.map(List.flatten(initial), &Value.get_typespec/1)

{region, [arg_token | arg_params]} = Function.push_region(state.builder, arg_typespecs)
{region, args} = Function.push_region(state.builder, arg_typespecs)

outer_token = get_token(cache)

{inner_token, arg_params} =
if outer_token do
[arg_token | arg_params] = args
{arg_token, arg_params}
else
{nil, args}
end

params = Enum.with_index(arg_params, &{&2, &1})

Expand All @@ -1675,11 +1694,15 @@ defmodule EXLA.Defn do
expr
end

{res, comp_cache} = recur_composite(expr, & &1, state, reset_token(cache, arg_token))
{res, comp_cache} = recur_composite(expr, & &1, state, reset_token(cache, inner_token))

res =
if type == :with_token do
[get_token(comp_cache) | List.flatten(res)]
if outer_token do
[get_token(comp_cache) | List.flatten(res)]
else
List.flatten(res)
end
else
Enum.map(res, &to_type(&1, type))
end
Expand All @@ -1690,63 +1713,51 @@ defmodule EXLA.Defn do
{region, merge_outfeed(cache, comp_cache)}
end

defp token_computation(
name,
args,
expr,
%{builder: %Function{runtime: runtime}} = state,
cache
) do
defp optional_computation(name, args, expr, %{builder: %Function{}} = state, cache) do
%Function{module: module, name: name} = subbuilder(state.builder, name)

token_typespec = Typespec.token()
arg_typespecs = Enum.map(args, &Value.get_typespec/1)
out_typespecs = container_to_typespecs(expr)

in_types =
if runtime == :iree do
arg_typespecs
else
[token_typespec | arg_typespecs]
end
outer_token = get_token(cache)
token_typespec = Typespec.token()

out_types =
if runtime == :iree do
out_typespecs
{arg_typespecs, out_typespecs} =
if outer_token do
{[token_typespec | arg_typespecs], [token_typespec | out_typespecs]}
else
[token_typespec | out_typespecs]
{arg_typespecs, out_typespecs}
end

function =
EXLA.MLIR.Module.add_function(module, name, in_types, out_types)
function = EXLA.MLIR.Module.add_function(module, name, arg_typespecs, out_typespecs)
args = EXLA.MLIR.Function.get_arguments(function)

function = %{function | runtime: runtime}

[arg_token | tail] = EXLA.MLIR.Function.get_arguments(function)

params =
if runtime == :iree do
Enum.with_index([arg_token | tail], fn param, i -> {i, param} end)
{inner_token, args} =
if outer_token do
[arg_token | args] = args
{arg_token, args}
else
Enum.with_index(tail, fn param, i -> {i, param} end)
{nil, args}
end

params = Enum.with_index(args, fn param, i -> {i, param} end)

state = %{
state
| builder: function,
params: Map.new(params),
scope_ids: Tree.scope_ids(expr)
}

if runtime == :iree do
{res, comp_cache} = recur_composite(expr, state, cache)
Value.func_return(function, List.flatten(res))
{function, merge_outfeed(cache, comp_cache)}
else
{res, comp_cache} = recur_composite(expr, state, reset_token(cache, arg_token))
{res, comp_cache} = recur_composite(expr, state, reset_token(cache, inner_token))

if outer_token do
Value.func_return(function, [get_token(comp_cache) | List.flatten(res)])
{function, merge_outfeed(cache, comp_cache)}
else
Value.func_return(function, List.flatten(res))
end

{function, merge_outfeed(cache, comp_cache)}
end

# The cache is built on top of call args because we need to handle pred/u8.
Expand Down Expand Up @@ -1920,10 +1931,10 @@ defmodule EXLA.Defn do

out_typespecs = container_to_typespecs(on_true)

in_token = get_token(cache)
outer_token = get_token(cache)

result_typespecs =
if in_token do
if outer_token do
[Typespec.token() | out_typespecs]
else
out_typespecs
Expand All @@ -1933,7 +1944,7 @@ defmodule EXLA.Defn do
{false_computation, cache} = to_mlir_if_branch(on_false, false_ids, state, cache)
if_results = Value.if_op(pred_op, true_computation, false_computation, result_typespecs)

if in_token do
if outer_token do
[token | results] = if_results
{wrap_tuple_result(results, on_true), update_token(cache, token)}
else
Expand Down
20 changes: 0 additions & 20 deletions exla/test/exla/defn/expr_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -1515,26 +1515,6 @@ defmodule EXLA.Defn.ExprTest do
end
end

defn while_inside_if(pred, x) do
if pred do
{x, _} =
while {x, i = 0}, i < 10 do
{x, i + 1}
end

x
else
x
end
end

test "while inside if" do
assert %{a: a, b: b} = while_inside_if(1, %{a: 1, b: 2.0})
assert_all_close(a, 1)
assert_all_close(b, 2.0)
end
end

describe "reduce" do
defn reduce(t), do: Nx.reduce(t, 1, fn a, b -> a * b end)
defn reduce_keep(t), do: Nx.reduce(t, 1, [keep_axes: true], fn a, b -> a * b end)
Expand Down

0 comments on commit 2f4e879

Please sign in to comment.