From 88f1ff8acb1afa9e745128d43cc3eb0da39fbadc Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Tue, 29 Oct 2024 22:49:34 -0300 Subject: [PATCH 1/6] fix: Nx.Random.shuffle --- exla/test/exla/random_test.exs | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/exla/test/exla/random_test.exs b/exla/test/exla/random_test.exs index 3b86fc64b9..67aa50c5f8 100644 --- a/exla/test/exla/random_test.exs +++ b/exla/test/exla/random_test.exs @@ -39,4 +39,23 @@ defmodule EXLA.NxRandomTest do ) end end + + @tag :cuda_required + test "regression on single-dimensional and multi-dimensional Random.shuffle" do + # these are put in the process dictionary, so it's thread-safe to do this + Nx.default_backend({EXLA.Backend, client: :cuda}) + Nx.Defn.default_options(compiler: EXLA, client: :cuda) + key = Nx.Random.key(127) + + t1 = Nx.iota({2, 100}) + t2 = Nx.iota({100}) + + {t1_shuffled_0, key} = Nx.Random.shuffle(key, t1, axis: 0) + {t1_shuffled_1, key} = Nx.Random.shuffle(key, t1, axis: 1) + {t2_shuffled, _key} = Nx.Random.shuffle(key, t2) + + assert_equal(Nx.sort(t1_shuffled_0, axis: 0), t1) + assert_equal(Nx.sort(t1_shuffled_1, axis: 1), t1) + assert_equal(Nx.sort(t2_shuffled), t2) + end end From de4c3612c4af89ae63d878d30a64932dd066a683 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Tue, 29 Oct 2024 22:49:34 -0300 Subject: [PATCH 2/6] fix: Nx.Random.shuffle --- exla/test/exla/random_test.exs | 19 +++++++++++++++++++ nx/lib/nx/random.ex | 2 +- 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/exla/test/exla/random_test.exs b/exla/test/exla/random_test.exs index 3b86fc64b9..67aa50c5f8 100644 --- a/exla/test/exla/random_test.exs +++ b/exla/test/exla/random_test.exs @@ -39,4 +39,23 @@ defmodule EXLA.NxRandomTest do ) end end + + @tag :cuda_required + test "regression on single-dimensional and multi-dimensional Random.shuffle" do + # these are put in the process dictionary, so it's thread-safe to do this + Nx.default_backend({EXLA.Backend, client: :cuda}) + Nx.Defn.default_options(compiler: EXLA, client: :cuda) + key = Nx.Random.key(127) + + t1 = Nx.iota({2, 100}) + t2 = Nx.iota({100}) + + {t1_shuffled_0, key} = Nx.Random.shuffle(key, t1, axis: 0) + {t1_shuffled_1, key} = Nx.Random.shuffle(key, t1, axis: 1) + {t2_shuffled, _key} = Nx.Random.shuffle(key, t2) + + assert_equal(Nx.sort(t1_shuffled_0, axis: 0), t1) + assert_equal(Nx.sort(t1_shuffled_1, axis: 1), t1) + assert_equal(Nx.sort(t2_shuffled), t2) + end end diff --git a/nx/lib/nx/random.ex b/nx/lib/nx/random.ex index 5b429517cb..f0de210548 100644 --- a/nx/lib/nx/random.ex +++ b/nx/lib/nx/random.ex @@ -830,7 +830,7 @@ defmodule Nx.Random do defnp sort_key_val(tensor, sort_keys, opts \\ []) do idx = Nx.argsort(sort_keys, axis: opts[:axis]) - Nx.take_along_axis(tensor, idx, axis: opts[:axis]) + Nx.take(tensor, idx, axis: opts[:axis]) end @choice_options """ From 39e954f9d9847f57444b25be86684564f8117229 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Tue, 29 Oct 2024 22:56:24 -0300 Subject: [PATCH 3/6] fix: switch between take and take_along_axis --- nx/lib/nx/random.ex | 25 ++++++++++++++++++++----- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/nx/lib/nx/random.ex b/nx/lib/nx/random.ex index f0de210548..e4f6c1e189 100644 --- a/nx/lib/nx/random.ex +++ b/nx/lib/nx/random.ex @@ -799,15 +799,21 @@ defmodule Nx.Random do axis = opts[:axis] if opts[:independent] do - shuffle_independent(key, tensor, axis: axis) + shuffle_independent(key, tensor, axis: axis, independent: true) else - {idx, key} = shuffle_independent(key, Nx.iota({Nx.axis_size(tensor, axis)}), axis: 0) + {idx, key} = + shuffle_independent(key, Nx.iota({Nx.axis_size(tensor, axis)}), + axis: 0, + independent: false + ) + {Nx.take(tensor, idx, axis: axis), key} end end defnp shuffle_independent(key, tensor, opts) do axis = opts[:axis] + independent = opts[:independent] # reference: https://github.com/google/jax/blob/838bc454895ed2086563301936fb0d6d852fd198/jax/_src/random.py#L437 exponent = 3 @@ -821,16 +827,25 @@ defmodule Nx.Random do while {i = 0, tensor, key}, i < num_rounds do keys = split(key) sort_keys = random_bits(keys[1], shape: tensor.shape) - tensor = sort_key_val(tensor, sort_keys, axis: axis) + tensor = sort_key_val(tensor, sort_keys, axis: axis, independent: independent) {i + 1, tensor, keys[0]} end {out, key} end - defnp sort_key_val(tensor, sort_keys, opts \\ []) do + deftransformp sort_key_val(tensor, sort_keys, opts \\ []) do idx = Nx.argsort(sort_keys, axis: opts[:axis]) - Nx.take(tensor, idx, axis: opts[:axis]) + + if opts[:independent] do + # We need to use take_along_axis in the independent case because + # the sort_keys tensor has the same shape as the input tensor. + Nx.take_along_axis(tensor, idx, axis: opts[:axis]) + else + # In the non-independent case, we use take because the sort_keys + # tensor is a 1D tensor. + Nx.take(tensor, idx, axis: opts[:axis]) + end end @choice_options """ From 08dc89336c5b15c42c2132b44f5aa28f93fd13ee Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Wed, 30 Oct 2024 01:51:15 -0300 Subject: [PATCH 4/6] fix: remove compare_type instead of setting notype --- exla/lib/exla/mlir/value.ex | 11 ++++------- nx/lib/nx/random.ex | 25 +++++-------------------- 2 files changed, 9 insertions(+), 27 deletions(-) diff --git a/exla/lib/exla/mlir/value.ex b/exla/lib/exla/mlir/value.ex index e38d09fc0b..8bc3195983 100644 --- a/exla/lib/exla/mlir/value.ex +++ b/exla/lib/exla/mlir/value.ex @@ -66,19 +66,16 @@ defmodule EXLA.MLIR.Value do comparison_type = cond do Nx.Type.complex?(lhs_type) or Nx.Type.complex?(rhs_type) -> - attr_comparison_type(:float) + [compare_type: attr_comparison_type(:float)] Nx.Type.float?(lhs_type) or Nx.Type.float?(rhs_type) -> - attr_comparison_type(:float) + [compare_type: attr_comparison_type(:float)] true -> - attr_comparison_type(:notype) + [] end - attributes = [ - comparison_direction: attr_comparison_direction(direction), - compare_type: comparison_type - ] + attributes = [comparison_direction: attr_comparison_direction(direction)] ++ comparison_type result_types = typespecs_to_mlir_types([Typespec.to_type(typespec, {:pred, 8})]) diff --git a/nx/lib/nx/random.ex b/nx/lib/nx/random.ex index e4f6c1e189..5b429517cb 100644 --- a/nx/lib/nx/random.ex +++ b/nx/lib/nx/random.ex @@ -799,21 +799,15 @@ defmodule Nx.Random do axis = opts[:axis] if opts[:independent] do - shuffle_independent(key, tensor, axis: axis, independent: true) + shuffle_independent(key, tensor, axis: axis) else - {idx, key} = - shuffle_independent(key, Nx.iota({Nx.axis_size(tensor, axis)}), - axis: 0, - independent: false - ) - + {idx, key} = shuffle_independent(key, Nx.iota({Nx.axis_size(tensor, axis)}), axis: 0) {Nx.take(tensor, idx, axis: axis), key} end end defnp shuffle_independent(key, tensor, opts) do axis = opts[:axis] - independent = opts[:independent] # reference: https://github.com/google/jax/blob/838bc454895ed2086563301936fb0d6d852fd198/jax/_src/random.py#L437 exponent = 3 @@ -827,25 +821,16 @@ defmodule Nx.Random do while {i = 0, tensor, key}, i < num_rounds do keys = split(key) sort_keys = random_bits(keys[1], shape: tensor.shape) - tensor = sort_key_val(tensor, sort_keys, axis: axis, independent: independent) + tensor = sort_key_val(tensor, sort_keys, axis: axis) {i + 1, tensor, keys[0]} end {out, key} end - deftransformp sort_key_val(tensor, sort_keys, opts \\ []) do + defnp sort_key_val(tensor, sort_keys, opts \\ []) do idx = Nx.argsort(sort_keys, axis: opts[:axis]) - - if opts[:independent] do - # We need to use take_along_axis in the independent case because - # the sort_keys tensor has the same shape as the input tensor. - Nx.take_along_axis(tensor, idx, axis: opts[:axis]) - else - # In the non-independent case, we use take because the sort_keys - # tensor is a 1D tensor. - Nx.take(tensor, idx, axis: opts[:axis]) - end + Nx.take_along_axis(tensor, idx, axis: opts[:axis]) end @choice_options """ From 80d16be39760198f6b21798640638b73e11f1d70 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Wed, 30 Oct 2024 02:17:30 -0300 Subject: [PATCH 5/6] feat: port fixes over from metal plugin branch Co-Authored-By: Jonatan Klosko --- exla/lib/exla/defn.ex | 38 ++++++++++++++++++++++++------------- exla/lib/exla/mlir/value.ex | 22 ++++++++++++++++----- 2 files changed, 42 insertions(+), 18 deletions(-) diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex index 5e8b94c5ab..d2f2fd7357 100644 --- a/exla/lib/exla/defn.ex +++ b/exla/lib/exla/defn.ex @@ -1367,30 +1367,42 @@ defmodule EXLA.Defn do ## Computation helpers - defp sort_computation(op, type, arg_typespecs, %{builder: %EXLA.MLIR.Function{} = function}) do + defp sort_computation(operator, type, arg_typespecs, %{ + builder: %EXLA.MLIR.Function{} = function + }) do {region, [lhs, rhs | _]} = Function.push_region(function, arg_typespecs) typespec = Typespec.tensor({:pred, 8}, {}) - op = - cond do - Nx.Type.integer?(type) -> - apply(Value, op, [lhs, rhs, typespec]) - - op == :less -> - is_nan = Value.is_nan(rhs, typespec) - Value.bitwise_or(is_nan, Value.less(lhs, rhs, typespec), typespec) - - op == :greater -> - is_nan = Value.is_nan(lhs, typespec) - Value.bitwise_or(is_nan, Value.greater(lhs, rhs, typespec), typespec) + {lhs, rhs} = + if Nx.Type.integer?(type) do + {lhs, rhs} + else + {sort_computation_canonicalize_float(lhs), sort_computation_canonicalize_float(rhs)} end + op = apply(Value, operator, [lhs, rhs, typespec, [total_order: true]]) + Value.return(function, [op]) Function.pop_region(function) region end + defp sort_computation_canonicalize_float(%Value{function: func} = op) do + # Standardize the representation of NaNs (-NaN, NaN) and zeros (-0, 0). + # See https://github.com/google/jax/blob/e81c82605f0e1813080cfe1037d043b27b38291d/jax/_src/lax/lax.py#L4248-L4253 + + op_typespec = Value.get_typespec(op) + + zero = Value.constant(func, [0], Typespec.to_shape(op_typespec, {})) + zeros = Value.constant(func, [0], op_typespec) + nans = Value.constant(func, [:nan], op_typespec) + + pred_typespec = Typespec.tensor({:pred, 8}, {}) + op = Value.select(Value.equal(op, zero, pred_typespec), zeros, op, op_typespec) + Value.select(Value.is_nan(op, pred_typespec), nans, op, op_typespec) + end + defp op_computation( op, arg_typespecs, diff --git a/exla/lib/exla/mlir/value.ex b/exla/lib/exla/mlir/value.ex index 8bc3195983..2b25c6f8f6 100644 --- a/exla/lib/exla/mlir/value.ex +++ b/exla/lib/exla/mlir/value.ex @@ -54,12 +54,17 @@ defmodule EXLA.MLIR.Value do } for {op, direction} <- @bin_comparison_ops do - def unquote(op)(%Value{function: func} = lhs, %Value{function: func} = rhs, typespec) do - compare_and_return_bool(func, lhs, rhs, typespec, unquote(direction)) + def unquote(op)( + %Value{function: func} = lhs, + %Value{function: func} = rhs, + typespec, + opts \\ [] + ) do + compare_and_return_bool(func, lhs, rhs, typespec, unquote(direction), opts[:total_order]) end end - defp compare_and_return_bool(func, lhs, rhs, typespec, direction) do + defp compare_and_return_bool(func, lhs, rhs, typespec, direction, total_order? \\ false) do %{type: lhs_type} = get_typespec(lhs) %{type: rhs_type} = get_typespec(rhs) @@ -69,7 +74,14 @@ defmodule EXLA.MLIR.Value do [compare_type: attr_comparison_type(:float)] Nx.Type.float?(lhs_type) or Nx.Type.float?(rhs_type) -> - [compare_type: attr_comparison_type(:float)] + attr = + if total_order? do + attr_comparison_type(:totalorder) + else + attr_comparison_type(:float) + end + + [compare_type: attr] true -> [] @@ -1069,7 +1081,7 @@ defmodule EXLA.MLIR.Value do defp attr_comparison_direction(value) when value in [:eq, :lt, :le, :gt, :ge, :ne], do: attr_enum("stablehlo", "comparison_direction", value) - defp attr_comparison_type(value) when value in [:float, :totalorder, :notype], + defp attr_comparison_type(value) when value in [:float, :totalorder], do: attr_enum("stablehlo", "comparison_type", value) defp attr_precision(value) when value in [:default, :high, :highest], From 451a27ff3aa6b135111117d6c54a495f157b1f87 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Wed, 30 Oct 2024 02:21:49 -0300 Subject: [PATCH 6/6] chore: revert test file --- exla/test/exla/random_test.exs | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/exla/test/exla/random_test.exs b/exla/test/exla/random_test.exs index 67aa50c5f8..3b86fc64b9 100644 --- a/exla/test/exla/random_test.exs +++ b/exla/test/exla/random_test.exs @@ -39,23 +39,4 @@ defmodule EXLA.NxRandomTest do ) end end - - @tag :cuda_required - test "regression on single-dimensional and multi-dimensional Random.shuffle" do - # these are put in the process dictionary, so it's thread-safe to do this - Nx.default_backend({EXLA.Backend, client: :cuda}) - Nx.Defn.default_options(compiler: EXLA, client: :cuda) - key = Nx.Random.key(127) - - t1 = Nx.iota({2, 100}) - t2 = Nx.iota({100}) - - {t1_shuffled_0, key} = Nx.Random.shuffle(key, t1, axis: 0) - {t1_shuffled_1, key} = Nx.Random.shuffle(key, t1, axis: 1) - {t2_shuffled, _key} = Nx.Random.shuffle(key, t2) - - assert_equal(Nx.sort(t1_shuffled_0, axis: 0), t1) - assert_equal(Nx.sort(t1_shuffled_1, axis: 1), t1) - assert_equal(Nx.sort(t2_shuffled), t2) - end end