From bb61c58ac2bbc371bca518b9269f1c023fe81f1a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonatan=20K=C5=82osko?= Date: Sun, 2 Jun 2024 10:57:36 +0200 Subject: [PATCH] Fix argmax/argmin behaviour with NaNs (#1499) --- exla/lib/exla/lib.ex | 95 +++++++++++++++++++++---------------- exla/lib/exla/mlir/value.ex | 29 ++--------- nx/lib/nx.ex | 18 +++++++ nx/lib/nx/binary_backend.ex | 2 +- nx/test/nx_test.exs | 4 +- 5 files changed, 79 insertions(+), 69 deletions(-) diff --git a/exla/lib/exla/lib.ex b/exla/lib/exla/lib.ex index 2b54e3a5da..686d98d7df 100644 --- a/exla/lib/exla/lib.ex +++ b/exla/lib/exla/lib.ex @@ -34,7 +34,7 @@ defmodule EXLA.Lib do def argmax(builder, op, type, opts \\ []) def argmax(%Function{} = builder, %Value{} = op, type, opts) do - argmin_or_max(builder, op, false, type, opts) + argmin_or_max(builder, op, :max, type, opts) end @doc """ @@ -49,37 +49,43 @@ defmodule EXLA.Lib do def argmin(builder, op, type, opts \\ []) def argmin(%Function{} = builder, %Value{} = op, type, opts) do - argmin_or_max(builder, op, true, type, opts) + argmin_or_max(builder, op, :min, type, opts) end - defp argmin_or_max(builder, %Value{} = op, is_min?, type, opts) do + defp argmin_or_max(builder, %Value{} = op, variant, type, opts) do tie_break = opts[:tie_break] || :low keep_axis = opts[:keep_axis] || false + axis = opts[:axis] op_typespec = Value.get_typespec(op) + {op, op_typespec} = + if axis == nil and Nx.rank(op_typespec.shape) != 1 do + # When no axis is given, we flatten the tensor and reduce over + # the first axis + typespec = Typespec.to_shape(op_typespec, {Nx.size(op_typespec.shape)}) + {Value.reshape(op, typespec), typespec} + else + {op, op_typespec} + end + + axis = axis || 0 + init_value = - if is_min?, - do: max_number(builder, op_typespec.type), - else: min_number(builder, op_typespec.type) + case variant do + :min -> max_number(builder, op_typespec.type) + :max -> min_number(builder, op_typespec.type) + end - axis = opts[:axis] index_init_value = Value.constant(builder, [0], Typespec.tensor(type, {})) iota = iota(builder, axis, Typespec.to_type(op_typespec, type)) - reduction = create_min_max_computation(builder, op_typespec.type, type, is_min?, tie_break) + reduction = create_min_max_computation(builder, op_typespec.type, type, variant, tie_break) - dims = - if axis do - [axis] - else - Nx.axes(op_typespec.shape) - end - - shape = remove_axes(op_typespec.shape, dims) + shape = Tuple.delete_at(op_typespec.shape, axis) typespecs = [Typespec.tensor(op_typespec.type, shape), Typespec.tensor(type, shape)] [_, result] = - Value.reduce(reduction, [init_value, index_init_value], [op, iota], dims, typespecs) + Value.reduce(reduction, [init_value, index_init_value], [op, iota], [axis], typespecs) if keep_axis do Value.reshape(result, Typespec.tensor(type, put_elem(op_typespec.shape, axis, 1))) @@ -88,13 +94,7 @@ defmodule EXLA.Lib do end end - defp remove_axes(shape, axes) do - axes - |> Enum.reverse() - |> Enum.reduce(shape, &Tuple.delete_at(&2, &1)) - end - - defp create_min_max_computation(%Function{} = function, type, index_type, is_min?, tie_break) do + defp create_min_max_computation(%Function{} = function, type, index_type, variant, tie_break) do arg_typespecs = [ Typespec.tensor(type, {}), Typespec.tensor(index_type, {}), @@ -109,27 +109,42 @@ defmodule EXLA.Lib do value_typespec = Typespec.tensor(type, {}) idx_typespec = Typespec.tensor(index_type, {}) - cmp = - if is_min?, - do: Value.less_equal(lhs_value, rhs_value, pred_typespec), - else: Value.greater_equal(lhs_value, rhs_value, pred_typespec) + comparator = + case variant do + :min -> &Value.less/3 + :max -> &Value.greater/3 + end + + # Pick lhs if strictly before or if it is NaN + pick_lhs_value = + Value.bitwise_or( + comparator.(lhs_value, rhs_value, pred_typespec), + Value.is_nan(lhs_value, pred_typespec), + pred_typespec + ) - max = Value.select(cmp, lhs_value, rhs_value, value_typespec) - arg_max = Value.select(cmp, lhs_index, rhs_index, idx_typespec) + max = Value.select(pick_lhs_value, lhs_value, rhs_value, value_typespec) - arg_max = + idx_comparator = case tie_break do - :low -> - eq? = Value.equal(lhs_value, rhs_value, pred_typespec) - id = Value.min(lhs_index, rhs_index, idx_typespec) - Value.select(eq?, id, arg_max, idx_typespec) - - :high -> - eq? = Value.equal(lhs_value, rhs_value, pred_typespec) - id = Value.max(lhs_index, rhs_index, idx_typespec) - Value.select(eq?, id, arg_max, idx_typespec) + :low -> &Value.less/3 + :high -> &Value.greater/3 end + # If lhs and rhs are equal (and not NaN), then pick index based on tie_break + pick_lhs_idx = + Value.bitwise_or( + pick_lhs_value, + Value.bitwise_and( + Value.equal(lhs_value, rhs_value, pred_typespec), + idx_comparator.(lhs_index, rhs_index, pred_typespec), + pred_typespec + ), + pred_typespec + ) + + arg_max = Value.select(pick_lhs_idx, lhs_index, rhs_index, idx_typespec) + Value.return(function, [max, arg_max]) Function.pop_region(function) region diff --git a/exla/lib/exla/mlir/value.ex b/exla/lib/exla/mlir/value.ex index 61208274b0..06b14fd099 100644 --- a/exla/lib/exla/mlir/value.ex +++ b/exla/lib/exla/mlir/value.ex @@ -157,34 +157,11 @@ defmodule EXLA.MLIR.Value do end end - def is_nan(%Value{function: func} = operand, out_typespec) do - %{type: type} = get_typespec(operand) - + def is_nan(%Value{} = operand, out_typespec) do typespec = Typespec.to_type(out_typespec, {:pred, 8}) - result = - cond do - Nx.Type.complex?(type) -> - float_typespec = Typespec.to_type(typespec, complex_part_type(type)) - real = real(operand, float_typespec) - imag = imag(operand, float_typespec) - is_nan_real = is_nan(real, typespec) - is_nan_imag = is_nan(imag, typespec) - bitwise_or(is_nan_real, is_nan_imag, typespec) - - Nx.Type.integer?(type) -> - # Integers are never nan. We use inequality to make sure - # the operand is still a part of the computation - not_equal(operand, operand, typespec) - - true -> - result_types = typespecs_to_mlir_types([typespec]) - is_inf = op(func, "chlo.is_inf", [operand], result_types) |> one!() - is_finite = op(func, "stablehlo.is_finite", [operand], result_types) |> one!() - is_not_inf = bitwise_not(is_inf, typespec) - is_not_finite = bitwise_not(is_finite, typespec) - bitwise_and(is_not_inf, is_not_finite, typespec) - end + # Only NaN is not equal to itself + result = not_equal(operand, operand, typespec) if out_typespec.type == typespec.type do result diff --git a/nx/lib/nx.ex b/nx/lib/nx.ex index a8fcee7488..aca74c01c7 100644 --- a/nx/lib/nx.ex +++ b/nx/lib/nx.ex @@ -10008,6 +10008,15 @@ defmodule Nx do 1 > + If the tensor includes any NaNs, returns the index of any of them + (NaNs are not equal, hence tie-break does not apply): + + iex> Nx.argmax(Nx.tensor([2.0, :nan, 4.0])) + #Nx.Tensor< + s64 + 1 + > + ### Aggregating over an axis iex> t = Nx.tensor([[[4, 2, 3], [1, -5, 3]], [[6, 2, 3], [4, 8, 3]]]) @@ -10147,6 +10156,15 @@ defmodule Nx do 0 > + If the tensor includes any NaNs, returns the index of any of them + (NaNs are not equal, hence tie-break does not apply): + + iex> Nx.argmin(Nx.tensor([2.0, :nan, 4.0])) + #Nx.Tensor< + s64 + 1 + > + ### Aggregating over an axis iex> t = Nx.tensor([[[4, 2, 3], [1, -5, 3]], [[6, 2, 3], [4, 8, 3]]]) diff --git a/nx/lib/nx/binary_backend.ex b/nx/lib/nx/binary_backend.ex index 32279b8115..cbf1539e9b 100644 --- a/nx/lib/nx/binary_backend.ex +++ b/nx/lib/nx/binary_backend.ex @@ -1461,7 +1461,7 @@ defmodule Nx.BinaryBackend do bin, {i, cur_extreme_x, cur_extreme_i} -> x = binary_to_number(bin, type) - if cur_extreme_x == :first or comparator.(x, cur_extreme_x) do + if cur_extreme_x == :first or x == :nan or comparator.(x, cur_extreme_x) do {i, {i + 1, x, i}} else {cur_extreme_i, {i + 1, cur_extreme_x, cur_extreme_i}} diff --git a/nx/test/nx_test.exs b/nx/test/nx_test.exs index 2da1ca42ed..28ee017e0d 100644 --- a/nx/test/nx_test.exs +++ b/nx/test/nx_test.exs @@ -1443,7 +1443,7 @@ defmodule NxTest do [:nan, 0, 1] ]) - assert Nx.argmin(t, axis: 1) == Nx.tensor([0, 0, 0, 0, 2, 2, 1, 1, 0, 0, 0, 0]) + assert Nx.argmin(t, axis: 1) == Nx.tensor([0, 1, 0, 0, 2, 1, 1, 1, 0, 1, 0, 0]) end test "raises for invalid :tie_break option" do @@ -1475,7 +1475,7 @@ defmodule NxTest do [:nan, 0, 1] ]) - assert Nx.argmax(t, axis: 1) == Nx.tensor([1, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0]) + assert Nx.argmax(t, axis: 1) == Nx.tensor([1, 1, 2, 2, 0, 1, 0, 0, 0, 1, 0, 0]) end end