
Fix argmax/argmin behaviour with NaNs (#1499)
jonatanklosko authored Jun 2, 2024
1 parent 7990b7e commit bb61c58
Showing 5 changed files with 79 additions and 69 deletions.
exla/lib/exla/lib.ex (55 additions, 40 deletions)
@@ -34,7 +34,7 @@ defmodule EXLA.Lib do
   def argmax(builder, op, type, opts \\ [])

   def argmax(%Function{} = builder, %Value{} = op, type, opts) do
-    argmin_or_max(builder, op, false, type, opts)
+    argmin_or_max(builder, op, :max, type, opts)
   end

   @doc """
@@ -49,37 +49,43 @@
   def argmin(builder, op, type, opts \\ [])

   def argmin(%Function{} = builder, %Value{} = op, type, opts) do
-    argmin_or_max(builder, op, true, type, opts)
+    argmin_or_max(builder, op, :min, type, opts)
   end

-  defp argmin_or_max(builder, %Value{} = op, is_min?, type, opts) do
+  defp argmin_or_max(builder, %Value{} = op, variant, type, opts) do
     tie_break = opts[:tie_break] || :low
     keep_axis = opts[:keep_axis] || false
+    axis = opts[:axis]

     op_typespec = Value.get_typespec(op)

+    {op, op_typespec} =
+      if axis == nil and Nx.rank(op_typespec.shape) != 1 do
+        # When no axis is given, we flatten the tensor and reduce over
+        # the first axis
+        typespec = Typespec.to_shape(op_typespec, {Nx.size(op_typespec.shape)})
+        {Value.reshape(op, typespec), typespec}
+      else
+        {op, op_typespec}
+      end
+
+    axis = axis || 0
+
     init_value =
-      if is_min?,
-        do: max_number(builder, op_typespec.type),
-        else: min_number(builder, op_typespec.type)
+      case variant do
+        :min -> max_number(builder, op_typespec.type)
+        :max -> min_number(builder, op_typespec.type)
+      end

-    axis = opts[:axis]
     index_init_value = Value.constant(builder, [0], Typespec.tensor(type, {}))
     iota = iota(builder, axis, Typespec.to_type(op_typespec, type))
-    reduction = create_min_max_computation(builder, op_typespec.type, type, is_min?, tie_break)
+    reduction = create_min_max_computation(builder, op_typespec.type, type, variant, tie_break)

-    dims =
-      if axis do
-        [axis]
-      else
-        Nx.axes(op_typespec.shape)
-      end
-
-    shape = remove_axes(op_typespec.shape, dims)
+    shape = Tuple.delete_at(op_typespec.shape, axis)
     typespecs = [Typespec.tensor(op_typespec.type, shape), Typespec.tensor(type, shape)]

     [_, result] =
-      Value.reduce(reduction, [init_value, index_init_value], [op, iota], dims, typespecs)
+      Value.reduce(reduction, [init_value, index_init_value], [op, iota], [axis], typespecs)

     if keep_axis do
       Value.reshape(result, Typespec.tensor(type, put_elem(op_typespec.shape, axis, 1)))
@@ -88,13 +94,7 @@
     end
   end

-  defp remove_axes(shape, axes) do
-    axes
-    |> Enum.reverse()
-    |> Enum.reduce(shape, &Tuple.delete_at(&2, &1))
-  end
-
-  defp create_min_max_computation(%Function{} = function, type, index_type, is_min?, tie_break) do
+  defp create_min_max_computation(%Function{} = function, type, index_type, variant, tie_break) do
     arg_typespecs = [
       Typespec.tensor(type, {}),
       Typespec.tensor(index_type, {}),
@@ -109,27 +109,42 @@
     value_typespec = Typespec.tensor(type, {})
     idx_typespec = Typespec.tensor(index_type, {})

-    cmp =
-      if is_min?,
-        do: Value.less_equal(lhs_value, rhs_value, pred_typespec),
-        else: Value.greater_equal(lhs_value, rhs_value, pred_typespec)
+    comparator =
+      case variant do
+        :min -> &Value.less/3
+        :max -> &Value.greater/3
+      end
+
+    # Pick lhs if strictly before or if it is NaN
+    pick_lhs_value =
+      Value.bitwise_or(
+        comparator.(lhs_value, rhs_value, pred_typespec),
+        Value.is_nan(lhs_value, pred_typespec),
+        pred_typespec
+      )

-    max = Value.select(cmp, lhs_value, rhs_value, value_typespec)
-    arg_max = Value.select(cmp, lhs_index, rhs_index, idx_typespec)
+    max = Value.select(pick_lhs_value, lhs_value, rhs_value, value_typespec)

-    arg_max =
+    idx_comparator =
       case tie_break do
-        :low ->
-          eq? = Value.equal(lhs_value, rhs_value, pred_typespec)
-          id = Value.min(lhs_index, rhs_index, idx_typespec)
-          Value.select(eq?, id, arg_max, idx_typespec)
-
-        :high ->
-          eq? = Value.equal(lhs_value, rhs_value, pred_typespec)
-          id = Value.max(lhs_index, rhs_index, idx_typespec)
-          Value.select(eq?, id, arg_max, idx_typespec)
+        :low -> &Value.less/3
+        :high -> &Value.greater/3
       end

+    # If lhs and rhs are equal (and not NaN), then pick index based on tie_break
+    pick_lhs_idx =
+      Value.bitwise_or(
+        pick_lhs_value,
+        Value.bitwise_and(
+          Value.equal(lhs_value, rhs_value, pred_typespec),
+          idx_comparator.(lhs_index, rhs_index, pred_typespec),
+          pred_typespec
+        ),
+        pred_typespec
+      )
+
+    arg_max = Value.select(pick_lhs_idx, lhs_index, rhs_index, idx_typespec)
+
     Value.return(function, [max, arg_max])
     Function.pop_region(function)
     region
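Read in ordinary Elixir terms, the reducer above keeps the left-hand candidate exactly when it sorts strictly before the right-hand one or is NaN, and consults the :tie_break comparison only on a genuine (non-NaN) tie. A minimal sketch of that rule follows; nan?, gt, lt, pick_lhs_value? and pick_lhs_idx? are illustrative names standing in for the Value.* ops built above (with NaN spelled as the :nan atom only for the sketch), not EXLA APIs.

    # Illustrative stand-ins for the Value.* ops in create_min_max_computation/5.
    nan? = fn x -> x == :nan end

    # IEEE-754 comparisons involving NaN are false; encode that explicitly,
    # since plain term comparison against the :nan atom would not behave so.
    gt = fn a, b -> not nan?.(a) and not nan?.(b) and a > b end
    lt = fn a, b -> not nan?.(a) and not nan?.(b) and a < b end

    # Keep the left value if it wins the comparison or is NaN.
    pick_lhs_value? = fn lv, rv, compare -> compare.(lv, rv) or nan?.(lv) end

    # Keep the left index if its value wins, or on a genuine tie when the
    # :tie_break comparator prefers the left index.
    pick_lhs_idx? = fn lv, rv, li, ri, compare, idx_compare ->
      pick_lhs_value?.(lv, rv, compare) or (lv == rv and idx_compare.(li, ri))
    end

    # argmax with tie_break: :low corresponds to compare = gt, idx_compare = lt:
    pick_lhs_value?.(:nan, 4.0, gt)         #=> true, NaN beats any number
    pick_lhs_idx?.(2.0, 2.0, 0, 1, gt, lt)  #=> true, lower index on a tie

Because NaN is never equal to anything, the lv == rv branch can only fire for ordinary values, which is why :tie_break no longer masks a NaN.
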
exla/lib/exla/mlir/value.ex (3 additions, 26 deletions)
@@ -157,34 +157,11 @@ defmodule EXLA.MLIR.Value do
     end
   end

-  def is_nan(%Value{function: func} = operand, out_typespec) do
-    %{type: type} = get_typespec(operand)
-
+  def is_nan(%Value{} = operand, out_typespec) do
     typespec = Typespec.to_type(out_typespec, {:pred, 8})

-    result =
-      cond do
-        Nx.Type.complex?(type) ->
-          float_typespec = Typespec.to_type(typespec, complex_part_type(type))
-          real = real(operand, float_typespec)
-          imag = imag(operand, float_typespec)
-          is_nan_real = is_nan(real, typespec)
-          is_nan_imag = is_nan(imag, typespec)
-          bitwise_or(is_nan_real, is_nan_imag, typespec)
-
-        Nx.Type.integer?(type) ->
-          # Integers are never nan. We use inequality to make sure
-          # the operand is still a part of the computation
-          not_equal(operand, operand, typespec)
-
-        true ->
-          result_types = typespecs_to_mlir_types([typespec])
-          is_inf = op(func, "chlo.is_inf", [operand], result_types) |> one!()
-          is_finite = op(func, "stablehlo.is_finite", [operand], result_types) |> one!()
-          is_not_inf = bitwise_not(is_inf, typespec)
-          is_not_finite = bitwise_not(is_finite, typespec)
-          bitwise_and(is_not_inf, is_not_finite, typespec)
-      end
+    # Only NaN is not equal to itself
+    result = not_equal(operand, operand, typespec)

     if out_typespec.type == typespec.type do
       result
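The rewrite leans on a single IEEE-754 fact: NaN is the only value that compares unequal to itself, so not_equal(operand, operand) doubles as an is_nan mask and the previous per-type special-casing collapses into one comparison (for integers the expression is simply always false, matching the old integer branch). A quick way to see the identity from Nx, assuming the EXLA backend is compiled and available:

    Nx.default_backend(EXLA.Backend)

    t = Nx.tensor([1.0, :nan, :infinity])

    # Only the NaN position is unequal to itself; comparisons return u8 masks.
    Nx.not_equal(t, t)
    #=> u8 tensor with values [0, 1, 0]
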
nx/lib/nx.ex (18 additions, 0 deletions)
@@ -10008,6 +10008,15 @@ defmodule Nx do
         1
       >

+  If the tensor includes any NaNs, returns the index of any of them
+  (NaNs are not equal, hence tie-break does not apply):
+
+      iex> Nx.argmax(Nx.tensor([2.0, :nan, 4.0]))
+      #Nx.Tensor<
+        s64
+        1
+      >
+
   ### Aggregating over an axis

       iex> t = Nx.tensor([[[4, 2, 3], [1, -5, 3]], [[6, 2, 3], [4, 8, 3]]])
@@ -10147,6 +10156,15 @@ defmodule Nx do
         0
       >

+  If the tensor includes any NaNs, returns the index of any of them
+  (NaNs are not equal, hence tie-break does not apply):
+
+      iex> Nx.argmin(Nx.tensor([2.0, :nan, 4.0]))
+      #Nx.Tensor<
+        s64
+        1
+      >
+
   ### Aggregating over an axis

       iex> t = Nx.tensor([[[4, 2, 3], [1, -5, 3]], [[6, 2, 3], [4, 8, 3]]])
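The new doc note can be made concrete: :tie_break only arbitrates between equal extremes, and since NaN never compares equal, a lone NaN wins under either setting. A small illustration (not part of the committed doctests):

    # :tie_break picks among equal maxima...
    Nx.argmax(Nx.tensor([1.0, 2.0, 2.0]))                     #=> index 1
    Nx.argmax(Nx.tensor([1.0, 2.0, 2.0]), tie_break: :high)   #=> index 2

    # ...but a NaN is never "equal", so it wins regardless of the option.
    Nx.argmax(Nx.tensor([2.0, :nan, 4.0]), tie_break: :high)  #=> index 1
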
nx/lib/nx/binary_backend.ex (1 addition, 1 deletion)
@@ -1461,7 +1461,7 @@ defmodule Nx.BinaryBackend do
       bin, {i, cur_extreme_x, cur_extreme_i} ->
         x = binary_to_number(bin, type)

-        if cur_extreme_x == :first or comparator.(x, cur_extreme_x) do
+        if cur_extreme_x == :first or x == :nan or comparator.(x, cur_extreme_x) do
           {i, {i + 1, x, i}}
         else
           {cur_extreme_i, {i + 1, cur_extreme_x, cur_extreme_i}}
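In the pure-Elixir backend a NaN element decodes to the :nan atom, so the fold only needs the extra x == :nan check to let a NaN become, and stay, the current extreme. A plain-Elixir sketch of the patched fold follows; the names are illustrative rather than the backend's internals, and the comparator is written to mirror IEEE-754 (comparisons involving NaN are false), which is the assumption the check relies on:

    nan? = fn x -> x == :nan end
    less = fn a, b -> not nan?.(a) and not nan?.(b) and a < b end

    argmin_index = fn values ->
      {_extreme, extreme_i, _count} =
        Enum.reduce(values, {:first, 0, 0}, fn x, {cur, cur_i, i} ->
          # Take x if it is the first element, a NaN, or strictly smaller.
          if cur == :first or nan?.(x) or less.(x, cur) do
            {x, i, i + 1}
          else
            {cur, cur_i, i + 1}
          end
        end)

      extreme_i
    end

    argmin_index.([2.0, :nan, 4.0])  #=> 1, the NaN index
    argmin_index.([3.0, 1.0, 1.0])   #=> 1, strict comparison keeps the first tie
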
nx/test/nx_test.exs (2 additions, 2 deletions)
@@ -1443,7 +1443,7 @@ defmodule NxTest do
           [:nan, 0, 1]
         ])

-      assert Nx.argmin(t, axis: 1) == Nx.tensor([0, 0, 0, 0, 2, 2, 1, 1, 0, 0, 0, 0])
+      assert Nx.argmin(t, axis: 1) == Nx.tensor([0, 1, 0, 0, 2, 1, 1, 1, 0, 1, 0, 0])
     end

     test "raises for invalid :tie_break option" do
@@ -1475,7 +1475,7 @@
           [:nan, 0, 1]
         ])

-      assert Nx.argmax(t, axis: 1) == Nx.tensor([1, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0])
+      assert Nx.argmax(t, axis: 1) == Nx.tensor([1, 1, 2, 2, 0, 1, 0, 0, 0, 1, 0, 0])
     end
   end
Expand Down
