Skip to content

Commit

Permalink
fix: Nx.Random.shuffle repeating a single value in certain cases on G…
Browse files Browse the repository at this point in the history
…PU (#1552)

Co-authored-by: Jonatan Klosko <[email protected]>
  • Loading branch information
2 people authored and josevalim committed Nov 16, 2024
1 parent d64ba46 commit d21aca5
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 24 deletions.
38 changes: 25 additions & 13 deletions exla/lib/exla/defn.ex
Original file line number Diff line number Diff line change
Expand Up @@ -1560,30 +1560,42 @@ defmodule EXLA.Defn do

## Computation helpers

defp sort_computation(op, type, arg_typespecs, %{builder: %EXLA.MLIR.Function{} = function}) do
defp sort_computation(operator, type, arg_typespecs, %{
builder: %EXLA.MLIR.Function{} = function
}) do
{region, [lhs, rhs | _]} = Function.push_region(function, arg_typespecs)

typespec = Typespec.tensor({:pred, 8}, {})

op =
cond do
Nx.Type.integer?(type) ->
apply(Value, op, [lhs, rhs, typespec])

op == :less ->
is_nan = Value.is_nan(rhs, typespec)
Value.bitwise_or(is_nan, Value.less(lhs, rhs, typespec), typespec)

op == :greater ->
is_nan = Value.is_nan(lhs, typespec)
Value.bitwise_or(is_nan, Value.greater(lhs, rhs, typespec), typespec)
{lhs, rhs} =
if Nx.Type.integer?(type) do
{lhs, rhs}
else
{sort_computation_canonicalize_float(lhs), sort_computation_canonicalize_float(rhs)}
end

op = apply(Value, operator, [lhs, rhs, typespec, [total_order: true]])

Value.return(function, [op])
Function.pop_region(function)
region
end

defp sort_computation_canonicalize_float(%Value{function: func} = op) do
# Standardize the representation of NaNs (-NaN, NaN) and zeros (-0, 0).
# See https://github.com/google/jax/blob/e81c82605f0e1813080cfe1037d043b27b38291d/jax/_src/lax/lax.py#L4248-L4253

op_typespec = Value.get_typespec(op)

zero = Value.constant(func, [0], Typespec.to_shape(op_typespec, {}))
zeros = Value.constant(func, [0], op_typespec)
nans = Value.constant(func, [:nan], op_typespec)

pred_typespec = Typespec.tensor({:pred, 8}, {})
op = Value.select(Value.equal(op, zero, pred_typespec), zeros, op, op_typespec)
Value.select(Value.is_nan(op, pred_typespec), nans, op, op_typespec)
end

defp op_computation(
op,
arg_typespecs,
Expand Down
31 changes: 20 additions & 11 deletions exla/lib/exla/mlir/value.ex
Original file line number Diff line number Diff line change
Expand Up @@ -54,31 +54,40 @@ defmodule EXLA.MLIR.Value do
}

for {op, direction} <- @bin_comparison_ops do
def unquote(op)(%Value{function: func} = lhs, %Value{function: func} = rhs, typespec) do
compare_and_return_bool(func, lhs, rhs, typespec, unquote(direction))
def unquote(op)(
%Value{function: func} = lhs,
%Value{function: func} = rhs,
typespec,
opts \\ []
) do
compare_and_return_bool(func, lhs, rhs, typespec, unquote(direction), opts[:total_order])
end
end

defp compare_and_return_bool(func, lhs, rhs, typespec, direction) do
defp compare_and_return_bool(func, lhs, rhs, typespec, direction, total_order? \\ false) do
%{type: lhs_type} = get_typespec(lhs)
%{type: rhs_type} = get_typespec(rhs)

comparison_type =
cond do
Nx.Type.complex?(lhs_type) or Nx.Type.complex?(rhs_type) ->
attr_comparison_type(:float)
[compare_type: attr_comparison_type(:float)]

Nx.Type.float?(lhs_type) or Nx.Type.float?(rhs_type) ->
attr_comparison_type(:float)
attr =
if total_order? do
attr_comparison_type(:totalorder)
else
attr_comparison_type(:float)
end

[compare_type: attr]

true ->
attr_comparison_type(:notype)
[]
end

attributes = [
comparison_direction: attr_comparison_direction(direction),
compare_type: comparison_type
]
attributes = [comparison_direction: attr_comparison_direction(direction)] ++ comparison_type

result_types = typespecs_to_mlir_types([Typespec.to_type(typespec, {:pred, 8})])

Expand Down Expand Up @@ -1072,7 +1081,7 @@ defmodule EXLA.MLIR.Value do
defp attr_comparison_direction(value) when value in [:eq, :lt, :le, :gt, :ge, :ne],
do: attr_enum("stablehlo", "comparison_direction", value)

defp attr_comparison_type(value) when value in [:float, :totalorder, :notype],
defp attr_comparison_type(value) when value in [:float, :totalorder],
do: attr_enum("stablehlo", "comparison_type", value)

defp attr_precision(value) when value in [:default, :high, :highest],
Expand Down

0 comments on commit d21aca5

Please sign in to comment.