Skip to content

Commit

Permalink
feat: use func.return when returning from func.func (#1495)
Browse files Browse the repository at this point in the history
Co-authored-by: Jonatan Kłosko <[email protected]>
  • Loading branch information
polvalente and jonatanklosko authored May 29, 2024
1 parent e0ed58a commit 7990b7e
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 47 deletions.
17 changes: 13 additions & 4 deletions exla/lib/exla/defn.ex
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ defmodule EXLA.Defn do
output = wrap_tuple_result(acc, acc_typespec)

outfeed = outfeed |> Outfeed.with_token(out_token) |> Outfeed.close(builder)
Value.return(builder, output)
Value.func_return(builder, output)

{{input_typespecs, input_indexes}, outfeed}
end
Expand Down Expand Up @@ -307,7 +307,7 @@ defmodule EXLA.Defn do
{res, cache} = recur_flatten(expr, state, new_cache(outfeed))
outfeed = cache |> get_outfeed() |> Outfeed.close(function)

Value.return(function, res)
Value.func_return(function, res)

{:ok, outfeed}
end
Expand Down Expand Up @@ -433,6 +433,15 @@ defmodule EXLA.Defn do
comp_arg_typespecs =
for {i, typespec} <- inputs_and_typespecs, i >= used_buffers, do: typespec

outputs =
if stream? do
# The computation returns the final accumulator value
{_chunk_result, acc} = outputs
acc
else
outputs
end

out_typespecs =
[outputs]
|> Nx.Defn.Composite.flatten_list()
Expand Down Expand Up @@ -1669,9 +1678,9 @@ defmodule EXLA.Defn do
{res, comp_cache} = recur_composite(expr, state, reset_token(cache, inner_token))

if outer_token do
Value.return(function, [get_token(comp_cache) | List.flatten(res)])
Value.func_return(function, [get_token(comp_cache) | List.flatten(res)])
else
Value.return(function, List.flatten(res))
Value.func_return(function, List.flatten(res))
end

{function, merge_outfeed(cache, comp_cache)}
Expand Down
98 changes: 58 additions & 40 deletions exla/lib/exla/mlir/value.ex
Original file line number Diff line number Diff line change
Expand Up @@ -125,57 +125,71 @@ defmodule EXLA.MLIR.Value do
end
end

def is_infinity(%Value{function: func} = operand, typespec) do
def is_infinity(%Value{function: func} = operand, out_typespec) do
%{type: type} = get_typespec(operand)

typespec = Typespec.to_type(typespec, {:pred, 8})
typespec = Typespec.to_type(out_typespec, {:pred, 8})

cond do
Nx.Type.complex?(type) ->
float_typespec = Typespec.to_type(typespec, complex_part_type(type))
real = real(operand, float_typespec)
imag = imag(operand, float_typespec)
is_inf_real = is_infinity(real, typespec)
is_inf_imag = is_infinity(imag, typespec)
bitwise_or(is_inf_real, is_inf_imag, typespec)

Nx.Type.integer?(type) ->
# Integers are never infinity. We use inequality to make sure
# the operand is still a part of the computation
not_equal(operand, operand, typespec)
result =
cond do
Nx.Type.complex?(type) ->
float_typespec = Typespec.to_type(typespec, complex_part_type(type))
real = real(operand, float_typespec)
imag = imag(operand, float_typespec)
is_inf_real = is_infinity(real, typespec)
is_inf_imag = is_infinity(imag, typespec)
bitwise_or(is_inf_real, is_inf_imag, typespec)

Nx.Type.integer?(type) ->
# Integers are never infinity. We use inequality to make sure
# the operand is still a part of the computation
not_equal(operand, operand, typespec)

true ->
result_types = typespecs_to_mlir_types([typespec])
op(func, "chlo.is_inf", [operand], result_types) |> one!()
end

true ->
result_types = typespecs_to_mlir_types([typespec])
op(func, "chlo.is_inf", [operand], result_types) |> one!()
if out_typespec.type == typespec.type do
result
else
convert(result, out_typespec)
end
end

def is_nan(%Value{function: func} = operand, typespec) do
def is_nan(%Value{function: func} = operand, out_typespec) do
%{type: type} = get_typespec(operand)

typespec = Typespec.to_type(typespec, {:pred, 8})
typespec = Typespec.to_type(out_typespec, {:pred, 8})

cond do
Nx.Type.complex?(type) ->
float_typespec = Typespec.to_type(typespec, complex_part_type(type))
real = real(operand, float_typespec)
imag = imag(operand, float_typespec)
is_nan_real = is_nan(real, typespec)
is_nan_imag = is_nan(imag, typespec)
bitwise_or(is_nan_real, is_nan_imag, typespec)

Nx.Type.integer?(type) ->
# Integers are never nan. We use inequality to make sure
# the operand is still a part of the computation
not_equal(operand, operand, typespec)
result =
cond do
Nx.Type.complex?(type) ->
float_typespec = Typespec.to_type(typespec, complex_part_type(type))
real = real(operand, float_typespec)
imag = imag(operand, float_typespec)
is_nan_real = is_nan(real, typespec)
is_nan_imag = is_nan(imag, typespec)
bitwise_or(is_nan_real, is_nan_imag, typespec)

Nx.Type.integer?(type) ->
# Integers are never nan. We use inequality to make sure
# the operand is still a part of the computation
not_equal(operand, operand, typespec)

true ->
result_types = typespecs_to_mlir_types([typespec])
is_inf = op(func, "chlo.is_inf", [operand], result_types) |> one!()
is_finite = op(func, "stablehlo.is_finite", [operand], result_types) |> one!()
is_not_inf = bitwise_not(is_inf, typespec)
is_not_finite = bitwise_not(is_finite, typespec)
bitwise_and(is_not_inf, is_not_finite, typespec)
end

true ->
result_types = typespecs_to_mlir_types([typespec])
is_inf = op(func, "chlo.is_inf", [operand], result_types) |> one!()
is_finite = op(func, "stablehlo.is_finite", [operand], result_types) |> one!()
is_not_inf = bitwise_not(is_inf, typespec)
is_not_finite = bitwise_not(is_finite, typespec)
bitwise_and(is_not_inf, is_not_finite, typespec)
if out_typespec.type == typespec.type do
result
else
convert(result, out_typespec)
end
end

Expand Down Expand Up @@ -706,6 +720,10 @@ defmodule EXLA.MLIR.Value do
op(func, "stablehlo.while", initial, result_types, regions: regions)
end

def func_return(func, values) when is_list(values) do
op(func, "func.return", values, [])
end

def return(func, values) when is_list(values) do
op(func, "stablehlo.return", values, [])
end
Expand Down
4 changes: 2 additions & 2 deletions exla/test/exla/executable_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ defmodule EXLA.ExecutableFeedTest do

assert res =
Task.async(fn ->
run_one([], [], [Typespec.token()], fn b ->
run_one([], [], [t.typespec], fn b ->
token = Value.create_token(b)

{new_token, [val]} = Value.infeed(token, [t.typespec])
Expand All @@ -185,7 +185,7 @@ defmodule EXLA.ExecutableFeedTest do

assert res =
Task.async(fn ->
run_one([], [], [token_shape, t.typespec], fn b ->
run_one([], [], [t.typespec], fn b ->
token = Value.create_token(b)

arg_shapes = [token_shape, t.typespec]
Expand Down
2 changes: 1 addition & 1 deletion exla/test/support/exla_helpers.ex
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ defmodule EXLAHelpers do

fun
|> apply([builder | params])
|> then(&EXLA.MLIR.Value.return(builder, List.wrap(&1)))
|> then(&EXLA.MLIR.Value.func_return(builder, List.wrap(&1)))

EXLA.MLIR.Module.compile(
builder.module,
Expand Down

0 comments on commit 7990b7e

Please sign in to comment.