diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex index 5e8b94c5ab..d2f2fd7357 100644 --- a/exla/lib/exla/defn.ex +++ b/exla/lib/exla/defn.ex @@ -1367,30 +1367,42 @@ defmodule EXLA.Defn do ## Computation helpers - defp sort_computation(op, type, arg_typespecs, %{builder: %EXLA.MLIR.Function{} = function}) do + defp sort_computation(operator, type, arg_typespecs, %{ + builder: %EXLA.MLIR.Function{} = function + }) do {region, [lhs, rhs | _]} = Function.push_region(function, arg_typespecs) typespec = Typespec.tensor({:pred, 8}, {}) - op = - cond do - Nx.Type.integer?(type) -> - apply(Value, op, [lhs, rhs, typespec]) - - op == :less -> - is_nan = Value.is_nan(rhs, typespec) - Value.bitwise_or(is_nan, Value.less(lhs, rhs, typespec), typespec) - - op == :greater -> - is_nan = Value.is_nan(lhs, typespec) - Value.bitwise_or(is_nan, Value.greater(lhs, rhs, typespec), typespec) + {lhs, rhs} = + if Nx.Type.integer?(type) do + {lhs, rhs} + else + {sort_computation_canonicalize_float(lhs), sort_computation_canonicalize_float(rhs)} end + op = apply(Value, operator, [lhs, rhs, typespec, [total_order: true]]) + Value.return(function, [op]) Function.pop_region(function) region end + defp sort_computation_canonicalize_float(%Value{function: func} = op) do + # Standardize the representation of NaNs (-NaN, NaN) and zeros (-0, 0). + # See https://github.com/google/jax/blob/e81c82605f0e1813080cfe1037d043b27b38291d/jax/_src/lax/lax.py#L4248-L4253 + + op_typespec = Value.get_typespec(op) + + zero = Value.constant(func, [0], Typespec.to_shape(op_typespec, {})) + zeros = Value.constant(func, [0], op_typespec) + nans = Value.constant(func, [:nan], op_typespec) + + pred_typespec = Typespec.tensor({:pred, 8}, {}) + op = Value.select(Value.equal(op, zero, pred_typespec), zeros, op, op_typespec) + Value.select(Value.is_nan(op, pred_typespec), nans, op, op_typespec) + end + defp op_computation( op, arg_typespecs, diff --git a/exla/lib/exla/mlir/value.ex b/exla/lib/exla/mlir/value.ex index e38d09fc0b..2b25c6f8f6 100644 --- a/exla/lib/exla/mlir/value.ex +++ b/exla/lib/exla/mlir/value.ex @@ -54,31 +54,40 @@ defmodule EXLA.MLIR.Value do } for {op, direction} <- @bin_comparison_ops do - def unquote(op)(%Value{function: func} = lhs, %Value{function: func} = rhs, typespec) do - compare_and_return_bool(func, lhs, rhs, typespec, unquote(direction)) + def unquote(op)( + %Value{function: func} = lhs, + %Value{function: func} = rhs, + typespec, + opts \\ [] + ) do + compare_and_return_bool(func, lhs, rhs, typespec, unquote(direction), opts[:total_order]) end end - defp compare_and_return_bool(func, lhs, rhs, typespec, direction) do + defp compare_and_return_bool(func, lhs, rhs, typespec, direction, total_order? \\ false) do %{type: lhs_type} = get_typespec(lhs) %{type: rhs_type} = get_typespec(rhs) comparison_type = cond do Nx.Type.complex?(lhs_type) or Nx.Type.complex?(rhs_type) -> - attr_comparison_type(:float) + [compare_type: attr_comparison_type(:float)] Nx.Type.float?(lhs_type) or Nx.Type.float?(rhs_type) -> - attr_comparison_type(:float) + attr = + if total_order? do + attr_comparison_type(:totalorder) + else + attr_comparison_type(:float) + end + + [compare_type: attr] true -> - attr_comparison_type(:notype) + [] end - attributes = [ - comparison_direction: attr_comparison_direction(direction), - compare_type: comparison_type - ] + attributes = [comparison_direction: attr_comparison_direction(direction)] ++ comparison_type result_types = typespecs_to_mlir_types([Typespec.to_type(typespec, {:pred, 8})]) @@ -1072,7 +1081,7 @@ defmodule EXLA.MLIR.Value do defp attr_comparison_direction(value) when value in [:eq, :lt, :le, :gt, :ge, :ne], do: attr_enum("stablehlo", "comparison_direction", value) - defp attr_comparison_type(value) when value in [:float, :totalorder, :notype], + defp attr_comparison_type(value) when value in [:float, :totalorder], do: attr_enum("stablehlo", "comparison_type", value) defp attr_precision(value) when value in [:default, :high, :highest],