diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex index 86cceb4b5d..482ad540df 100644 --- a/exla/lib/exla/defn.ex +++ b/exla/lib/exla/defn.ex @@ -237,7 +237,7 @@ defmodule EXLA.Defn do output = wrap_tuple_result(acc, acc_typespec) outfeed = outfeed |> Outfeed.with_token(out_token) |> Outfeed.close(builder) - Value.return(builder, output) + Value.func_return(builder, output) {{input_typespecs, input_indexes}, outfeed} end @@ -307,7 +307,7 @@ defmodule EXLA.Defn do {res, cache} = recur_flatten(expr, state, new_cache(outfeed)) outfeed = cache |> get_outfeed() |> Outfeed.close(function) - Value.return(function, res) + Value.func_return(function, res) {:ok, outfeed} end @@ -433,6 +433,15 @@ defmodule EXLA.Defn do comp_arg_typespecs = for {i, typespec} <- inputs_and_typespecs, i >= used_buffers, do: typespec + outputs = + if stream? do + # The computation returns the final accumulator value + {_chunk_result, acc} = outputs + acc + else + outputs + end + out_typespecs = [outputs] |> Nx.Defn.Composite.flatten_list() @@ -1669,9 +1678,9 @@ defmodule EXLA.Defn do {res, comp_cache} = recur_composite(expr, state, reset_token(cache, inner_token)) if outer_token do - Value.return(function, [get_token(comp_cache) | List.flatten(res)]) + Value.func_return(function, [get_token(comp_cache) | List.flatten(res)]) else - Value.return(function, List.flatten(res)) + Value.func_return(function, List.flatten(res)) end {function, merge_outfeed(cache, comp_cache)} diff --git a/exla/lib/exla/mlir/value.ex b/exla/lib/exla/mlir/value.ex index 62ec3ff100..61208274b0 100644 --- a/exla/lib/exla/mlir/value.ex +++ b/exla/lib/exla/mlir/value.ex @@ -125,57 +125,71 @@ defmodule EXLA.MLIR.Value do end end - def is_infinity(%Value{function: func} = operand, typespec) do + def is_infinity(%Value{function: func} = operand, out_typespec) do %{type: type} = get_typespec(operand) - typespec = Typespec.to_type(typespec, {:pred, 8}) + typespec = Typespec.to_type(out_typespec, {:pred, 8}) - cond do - Nx.Type.complex?(type) -> - float_typespec = Typespec.to_type(typespec, complex_part_type(type)) - real = real(operand, float_typespec) - imag = imag(operand, float_typespec) - is_inf_real = is_infinity(real, typespec) - is_inf_imag = is_infinity(imag, typespec) - bitwise_or(is_inf_real, is_inf_imag, typespec) - - Nx.Type.integer?(type) -> - # Integers are never infinity. We use inequality to make sure - # the operand is still a part of the computation - not_equal(operand, operand, typespec) + result = + cond do + Nx.Type.complex?(type) -> + float_typespec = Typespec.to_type(typespec, complex_part_type(type)) + real = real(operand, float_typespec) + imag = imag(operand, float_typespec) + is_inf_real = is_infinity(real, typespec) + is_inf_imag = is_infinity(imag, typespec) + bitwise_or(is_inf_real, is_inf_imag, typespec) + + Nx.Type.integer?(type) -> + # Integers are never infinity. We use inequality to make sure + # the operand is still a part of the computation + not_equal(operand, operand, typespec) + + true -> + result_types = typespecs_to_mlir_types([typespec]) + op(func, "chlo.is_inf", [operand], result_types) |> one!() + end - true -> - result_types = typespecs_to_mlir_types([typespec]) - op(func, "chlo.is_inf", [operand], result_types) |> one!() + if out_typespec.type == typespec.type do + result + else + convert(result, out_typespec) end end - def is_nan(%Value{function: func} = operand, typespec) do + def is_nan(%Value{function: func} = operand, out_typespec) do %{type: type} = get_typespec(operand) - typespec = Typespec.to_type(typespec, {:pred, 8}) + typespec = Typespec.to_type(out_typespec, {:pred, 8}) - cond do - Nx.Type.complex?(type) -> - float_typespec = Typespec.to_type(typespec, complex_part_type(type)) - real = real(operand, float_typespec) - imag = imag(operand, float_typespec) - is_nan_real = is_nan(real, typespec) - is_nan_imag = is_nan(imag, typespec) - bitwise_or(is_nan_real, is_nan_imag, typespec) - - Nx.Type.integer?(type) -> - # Integers are never nan. We use inequality to make sure - # the operand is still a part of the computation - not_equal(operand, operand, typespec) + result = + cond do + Nx.Type.complex?(type) -> + float_typespec = Typespec.to_type(typespec, complex_part_type(type)) + real = real(operand, float_typespec) + imag = imag(operand, float_typespec) + is_nan_real = is_nan(real, typespec) + is_nan_imag = is_nan(imag, typespec) + bitwise_or(is_nan_real, is_nan_imag, typespec) + + Nx.Type.integer?(type) -> + # Integers are never nan. We use inequality to make sure + # the operand is still a part of the computation + not_equal(operand, operand, typespec) + + true -> + result_types = typespecs_to_mlir_types([typespec]) + is_inf = op(func, "chlo.is_inf", [operand], result_types) |> one!() + is_finite = op(func, "stablehlo.is_finite", [operand], result_types) |> one!() + is_not_inf = bitwise_not(is_inf, typespec) + is_not_finite = bitwise_not(is_finite, typespec) + bitwise_and(is_not_inf, is_not_finite, typespec) + end - true -> - result_types = typespecs_to_mlir_types([typespec]) - is_inf = op(func, "chlo.is_inf", [operand], result_types) |> one!() - is_finite = op(func, "stablehlo.is_finite", [operand], result_types) |> one!() - is_not_inf = bitwise_not(is_inf, typespec) - is_not_finite = bitwise_not(is_finite, typespec) - bitwise_and(is_not_inf, is_not_finite, typespec) + if out_typespec.type == typespec.type do + result + else + convert(result, out_typespec) end end @@ -706,6 +720,10 @@ defmodule EXLA.MLIR.Value do op(func, "stablehlo.while", initial, result_types, regions: regions) end + def func_return(func, values) when is_list(values) do + op(func, "func.return", values, []) + end + def return(func, values) when is_list(values) do op(func, "stablehlo.return", values, []) end diff --git a/exla/test/exla/executable_test.exs b/exla/test/exla/executable_test.exs index 28e276edfc..49f33f9280 100644 --- a/exla/test/exla/executable_test.exs +++ b/exla/test/exla/executable_test.exs @@ -160,7 +160,7 @@ defmodule EXLA.ExecutableFeedTest do assert res = Task.async(fn -> - run_one([], [], [Typespec.token()], fn b -> + run_one([], [], [t.typespec], fn b -> token = Value.create_token(b) {new_token, [val]} = Value.infeed(token, [t.typespec]) @@ -185,7 +185,7 @@ defmodule EXLA.ExecutableFeedTest do assert res = Task.async(fn -> - run_one([], [], [token_shape, t.typespec], fn b -> + run_one([], [], [t.typespec], fn b -> token = Value.create_token(b) arg_shapes = [token_shape, t.typespec] diff --git a/exla/test/support/exla_helpers.ex b/exla/test/support/exla_helpers.ex index db7689d144..971d097ab3 100644 --- a/exla/test/support/exla_helpers.ex +++ b/exla/test/support/exla_helpers.ex @@ -15,7 +15,7 @@ defmodule EXLAHelpers do fun |> apply([builder | params]) - |> then(&EXLA.MLIR.Value.return(builder, List.wrap(&1))) + |> then(&EXLA.MLIR.Value.func_return(builder, List.wrap(&1))) EXLA.MLIR.Module.compile( builder.module,