From 0f62077bdc9b1116aa64dc5a4dc1ac4530d4ddde Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Thu, 9 Nov 2023 19:22:34 -0300 Subject: [PATCH 1/6] fix: grad for indexed ops with axes --- nx/lib/nx/binary_backend.ex | 10 +++ nx/lib/nx/defn/grad.ex | 16 +++- nx/test/nx/defn/grad_test.exs | 134 ++++++++++++++++++++++++---------- 3 files changed, 118 insertions(+), 42 deletions(-) diff --git a/nx/lib/nx/binary_backend.ex b/nx/lib/nx/binary_backend.ex index b595cfdb1a..38987bf270 100644 --- a/nx/lib/nx/binary_backend.ex +++ b/nx/lib/nx/binary_backend.ex @@ -2097,10 +2097,20 @@ defmodule Nx.BinaryBackend do inverse_permutation = permutation + |> Enum.filter(&(&1 in Nx.axes(out))) |> Enum.with_index() |> Enum.sort_by(fn {x, _} -> x end) |> Enum.map(fn {_, i} -> i end) + diff = Nx.rank(out) - length(inverse_permutation) + + inverse_permutation = + if diff > 0 do + Enum.to_list(0..(diff - 1)) ++ Enum.map(inverse_permutation, &(&1 + diff)) + else + inverse_permutation + end + { &Nx.transpose(&1, axes: permutation), &Nx.transpose(&1, axes: inverse_permutation) diff --git a/nx/lib/nx/defn/grad.ex b/nx/lib/nx/defn/grad.ex index 3c32c5f14a..57ffc26c25 100644 --- a/nx/lib/nx/defn/grad.ex +++ b/nx/lib/nx/defn/grad.ex @@ -806,11 +806,19 @@ defmodule Nx.Defn.Grad do end defp grad(:gather, [t, i, opts], _ans, g) do - rank = Nx.rank(t) - num_elements = i.shape |> Tuple.product() |> div(rank) + leading_i_shape = i.shape |> Tuple.delete_at(tuple_size(i.shape) - 1) + num_elements = Tuple.product(leading_i_shape) - indices = Nx.reshape(i, {num_elements, rank}) - updates = Nx.reshape(g, {num_elements}) + indices = Nx.reshape(i, {num_elements, :auto}) + + num_axes = length(opts[:axes]) + + updates_shape = + Enum.reduce(Enum.sort(opts[:axes], :desc), t.shape, fn axis, shape -> + Tuple.delete_at(shape, axis) + end) + + updates = Nx.reshape(g, Tuple.insert_at(updates_shape, 0, num_elements)) g = t |> Expr.broadcast(0, t.shape, Nx.axes(t)) |> Nx.indexed_add(indices, updates, opts) [{t, g}] diff --git a/nx/test/nx/defn/grad_test.exs b/nx/test/nx/defn/grad_test.exs index d2ae7d5f43..e5a8df50e3 100644 --- a/nx/test/nx/defn/grad_test.exs +++ b/nx/test/nx/defn/grad_test.exs @@ -3355,13 +3355,13 @@ defmodule Nx.Defn.GradTest do ) end - defn grad_sum_log_power_gather_cos(t, i) do + defn grad_sum_log_power_gather_cos(t, i, opts \\ []) do grad( t, fn t -> t |> Nx.cos() - |> Nx.gather(i) + |> Nx.gather(i, opts) |> Nx.pow(2) |> Nx.log() |> Nx.sum() @@ -3427,26 +3427,34 @@ defmodule Nx.Defn.GradTest do ]) ) - assert Nx.tensor([ - [-0.0, -9.34444522857666, 4.370079040527344, -0.0], - [-2.3156425952911377, 6.7610297203063965, 0.0, -0.0], - [13.5994234085083, -0.0, -1.2967215776443481, 1355.705078125] - ]) == - grad_sum_log_power_gather_cos( - Nx.tensor([ - [0, 1, 2, 3], - [4, 5, 6, 7], - [8, 9, 10, 11] - ]), - Nx.tensor([ - [ - [[0, 0], [0, 1], [0, 2]], - [[2, 0], [1, 0], [0, 1]], - [[0, 1], [1, 1], [2, 2]], - [[2, 3], [2, 3], [2, 3]] - ] - ]) - ) + t = + Nx.tensor([ + [0, 1, 2, 3], + [4, 5, 6, 7], + [8, 9, 10, 11] + ]) + + i = + Nx.tensor([ + [ + [[0, 0], [0, 1], [0, 2]], + [[2, 0], [1, 0], [0, 1]], + [[0, 1], [1, 1], [2, 2]], + [[2, 3], [2, 3], [2, 3]] + ] + ]) + + result = + Nx.tensor([ + [-0.0, -9.34444522857666, 4.370079040527344, -0.0], + [-2.3156425952911377, 6.7610297203063965, 0.0, -0.0], + [13.5994234085083, -0.0, -1.2967215776443481, 1355.705078125] + ]) + + assert result == grad_sum_log_power_gather_cos(t, i) + + assert Nx.new_axis(result, 1) == + 
grad_sum_log_power_gather_cos(Nx.new_axis(t, 1), i, axes: [0, 2]) end end @@ -3555,11 +3563,11 @@ defmodule Nx.Defn.GradTest do describe "indexed_put" do defn grad_indexed_put_target(t, i, u), do: grad(t, &Nx.sum(Nx.indexed_put(&1, i, u))) - defn grad_indexed_put_target_composite(t, i, u) do + defn grad_indexed_put_target_composite(t, i, u, opts \\ []) do grad(t, fn t -> t |> Nx.cos() - |> Nx.indexed_put(i, u) + |> Nx.indexed_put(i, u, opts) |> Nx.sin() |> Nx.sum() end) @@ -3567,10 +3575,10 @@ defmodule Nx.Defn.GradTest do defn grad_indexed_put_updates(t, i, u), do: grad(u, &Nx.sum(Nx.indexed_put(t, i, &1))) - defn grad_indexed_put_updates_composite(t, i, u) do + defn grad_indexed_put_updates_composite(t, i, u, opts \\ []) do grad(u, fn u -> t - |> Nx.indexed_put(i, Nx.cos(u)) + |> Nx.indexed_put(i, Nx.cos(u), opts) |> Nx.sin() |> Nx.sum() end) @@ -3578,19 +3586,19 @@ defmodule Nx.Defn.GradTest do defn grad_indexed_put_indices(t, i, u), do: grad(i, &Nx.sum(Nx.indexed_put(t, &1, u))) - defn grad_indexed_put_indices_composite(t, i, u) do + defn grad_indexed_put_indices_composite(t, i, u, opts \\ []) do grad(i, fn i -> t - |> Nx.indexed_put(Nx.multiply(i, 2), u) + |> Nx.indexed_put(Nx.multiply(i, 2), u, opts) |> Nx.sin() |> Nx.sum() end) end - defn grad_indexed_put_simultaneous_composite(t, i) do + defn grad_indexed_put_simultaneous_composite(t, i, opts \\ []) do grad(t, fn t -> t - |> Nx.indexed_put(i, Nx.cos(t)) + |> Nx.indexed_put(i, Nx.cos(t), opts) |> Nx.sin() |> Nx.sum() end) @@ -3618,6 +3626,16 @@ defmodule Nx.Defn.GradTest do ]), grad_indexed_put_target_composite(t, i, u) ) + + assert_all_close( + Nx.tensor([ + [0, 0, -0.8316, -0.0774], + [0, 0.9206, 0.1602, -0.4789], + [-0.9789, -0.2525, 0, 0] + ]) + |> Nx.new_axis(1), + grad_indexed_put_target_composite(Nx.new_axis(t, 1), i, Nx.new_axis(u, 1), axes: [0, 2]) + ) end test "grad wrt to source" do @@ -3631,6 +3649,11 @@ defmodule Nx.Defn.GradTest do # f'(x) = cos(cos(x)) * (-sin(x)) expected = u |> Nx.cos() |> Nx.cos() |> Nx.multiply(Nx.sin(u)) |> Nx.negate() assert_all_close(expected, grad_indexed_put_updates_composite(t, i, u)) + + assert_all_close( + Nx.new_axis(expected, 1), + grad_indexed_put_updates_composite(Nx.new_axis(t, 1), i, Nx.new_axis(u, 1), axes: [0, 2]) + ) end test "grad wrt to indices" do @@ -3640,6 +3663,11 @@ defmodule Nx.Defn.GradTest do assert_all_close(Nx.broadcast(0, i), grad_indexed_put_indices(t, i, u)) assert_all_close(Nx.broadcast(0, i), grad_indexed_put_indices_composite(t, i, u)) + + assert_all_close( + Nx.broadcast(0, i), + grad_indexed_put_indices_composite(Nx.new_axis(t, 1), i, Nx.new_axis(u, 1), axes: [0, 2]) + ) end test "grad wrt to both source and target simultaneously" do @@ -3653,6 +3681,11 @@ defmodule Nx.Defn.GradTest do expected = t |> Nx.cos() |> Nx.cos() |> Nx.multiply(Nx.sin(t)) |> Nx.negate() assert_all_close(expected, grad_indexed_put_simultaneous_composite(t, i)) + + assert_all_close( + Nx.new_axis(expected, 1), + grad_indexed_put_simultaneous_composite(Nx.new_axis(t, 1), i, axes: [0]) + ) end end @@ -3666,11 +3699,11 @@ defmodule Nx.Defn.GradTest do end) end - defn grad_indexed_add_target_composite(t, i, u) do + defn grad_indexed_add_target_composite(t, i, u, opts \\ []) do grad(t, fn t -> t |> Nx.cos() - |> Nx.indexed_add(i, u) + |> Nx.indexed_add(i, u, opts) |> Nx.sin() |> Nx.sum() end) @@ -3685,10 +3718,10 @@ defmodule Nx.Defn.GradTest do end) end - defn grad_indexed_add_updates_composite(t, i, u) do + defn grad_indexed_add_updates_composite(t, i, u, opts \\ []) do grad(u, fn u -> t - 
|> Nx.indexed_add(i, Nx.cos(u)) + |> Nx.indexed_add(i, Nx.cos(u), opts) |> Nx.sin() |> Nx.sum() end) @@ -3696,19 +3729,19 @@ defmodule Nx.Defn.GradTest do defn grad_indexed_add_indices(t, i, u), do: grad(i, &Nx.sum(Nx.indexed_add(t, &1, u))) - defn grad_indexed_add_indices_composite(t, i, u) do + defn grad_indexed_add_indices_composite(t, i, u, opts \\ []) do grad(i, fn i -> t - |> Nx.indexed_add(Nx.multiply(i, 2), u) + |> Nx.indexed_add(Nx.multiply(i, 2), u, opts) |> Nx.sin() |> Nx.sum() end) end - defn grad_indexed_add_simultaneous_composite(t, i) do + defn grad_indexed_add_simultaneous_composite(t, i, opts \\ []) do grad(t, fn t -> t - |> Nx.indexed_add(i, Nx.cos(t)) + |> Nx.indexed_add(i, Nx.cos(t), opts) |> Nx.sin() |> Nx.sum() end) @@ -3739,6 +3772,16 @@ defmodule Nx.Defn.GradTest do ]), grad_indexed_add_target_composite(t, i, u) ) + + assert_all_close( + Nx.tensor([ + [0, -0.0932, -0.8316, -0.0774], + [0.1684, 0.9206, 0.1602, -0.4789], + [-0.9789, -0.2525, -0.1442, -0.9905] + ]) + |> Nx.new_axis(1), + grad_indexed_add_target_composite(Nx.new_axis(t, 1), i, Nx.new_axis(u, 1), axes: [0, 2]) + ) end test "grad wrt to source" do @@ -3756,6 +3799,11 @@ defmodule Nx.Defn.GradTest do expected = cosx_tn |> Nx.cos() |> Nx.multiply(Nx.sin(u)) |> Nx.negate() assert_all_close(expected, grad_indexed_add_updates_composite(t, i, u)) + + assert_all_close( + Nx.new_axis(expected, 1), + grad_indexed_add_updates_composite(Nx.new_axis(t, 1), i, Nx.new_axis(u, 1), axes: [0, 2]) + ) end test "grad wrt to indices" do @@ -3765,6 +3813,11 @@ defmodule Nx.Defn.GradTest do assert_all_close(Nx.broadcast(0, i), grad_indexed_add_indices(t, i, u)) assert_all_close(Nx.broadcast(0, i), grad_indexed_add_indices_composite(t, i, u)) + + assert_all_close( + Nx.broadcast(0, i), + grad_indexed_add_indices_composite(Nx.new_axis(t, 1), i, Nx.new_axis(u, 1), axes: [0, 2]) + ) end test "grad wrt to both source and target simultaneously" do @@ -3779,6 +3832,11 @@ defmodule Nx.Defn.GradTest do expected = cosx_tn |> Nx.cos() |> Nx.multiply(Nx.subtract(1, Nx.sin(t))) assert_all_close(expected, grad_indexed_add_simultaneous_composite(t, i)) + + assert_all_close( + Nx.new_axis(expected, 1), + grad_indexed_add_simultaneous_composite(Nx.new_axis(t, 1), i, axes: [0]) + ) end end From 65572940c1c0d14cfb614e3c9a573b22698de9e3 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Thu, 9 Nov 2023 19:24:22 -0300 Subject: [PATCH 2/6] chore: resolve warnings --- nx/lib/nx/defn/grad.ex | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/nx/lib/nx/defn/grad.ex b/nx/lib/nx/defn/grad.ex index 57ffc26c25..24e4294a4b 100644 --- a/nx/lib/nx/defn/grad.ex +++ b/nx/lib/nx/defn/grad.ex @@ -806,13 +806,11 @@ defmodule Nx.Defn.Grad do end defp grad(:gather, [t, i, opts], _ans, g) do - leading_i_shape = i.shape |> Tuple.delete_at(tuple_size(i.shape) - 1) + leading_i_shape = Tuple.delete_at(i.shape, tuple_size(i.shape) - 1) num_elements = Tuple.product(leading_i_shape) indices = Nx.reshape(i, {num_elements, :auto}) - num_axes = length(opts[:axes]) - updates_shape = Enum.reduce(Enum.sort(opts[:axes], :desc), t.shape, fn axis, shape -> Tuple.delete_at(shape, axis) From 6e7bea1f3f1ada7809e8d579f633f02828a4024b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Valim?= Date: Sat, 11 Nov 2023 17:33:01 +0100 Subject: [PATCH 3/6] Improve grad --- nx/lib/nx/defn/grad.ex | 20 +++++++++++--------- nx/lib/nx/shape.ex | 2 +- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git 
a/nx/lib/nx/defn/grad.ex b/nx/lib/nx/defn/grad.ex index 24e4294a4b..20dcab161c 100644 --- a/nx/lib/nx/defn/grad.ex +++ b/nx/lib/nx/defn/grad.ex @@ -806,19 +806,21 @@ defmodule Nx.Defn.Grad do end defp grad(:gather, [t, i, opts], _ans, g) do - leading_i_shape = Tuple.delete_at(i.shape, tuple_size(i.shape) - 1) - num_elements = Tuple.product(leading_i_shape) + i_axes = opts[:axes] + i_shape = i.shape + t_shape = t.shape - indices = Nx.reshape(i, {num_elements, :auto}) + num_elements = Tuple.product(i_shape) |> div(elem(i_shape, tuple_size(i_shape) - 1)) + updates_shape = for i <- Nx.axes(t), i not in i_axes, do: elem(t_shape, i) - updates_shape = - Enum.reduce(Enum.sort(opts[:axes], :desc), t.shape, fn axis, shape -> - Tuple.delete_at(shape, axis) - end) + indices = Nx.reshape(i, {num_elements, :auto}) + updates = Nx.reshape(g, List.to_tuple([num_elements | updates_shape])) - updates = Nx.reshape(g, Tuple.insert_at(updates_shape, 0, num_elements)) + g = + 0 + |> Nx.broadcast(t_shape) + |> Nx.indexed_add(indices, updates, opts) - g = t |> Expr.broadcast(0, t.shape, Nx.axes(t)) |> Nx.indexed_add(indices, updates, opts) [{t, g}] end diff --git a/nx/lib/nx/shape.ex b/nx/lib/nx/shape.ex index a55f76054a..db097ca7da 100644 --- a/nx/lib/nx/shape.ex +++ b/nx/lib/nx/shape.ex @@ -1641,7 +1641,7 @@ defmodule Nx.Shape do "expected the last indices dimension size (#{last_size}) to be less than or equal to the tensor rank (#{rank})" end - inner_shape = for i <- 0..(rank - 1), i not in axes, do: elem(shape, i) + inner_shape = for i <- Nx.axes(shape), i not in axes, do: elem(shape, i) shape = List.to_tuple(outer_shape ++ inner_shape) names = List.duplicate(nil, tuple_size(shape)) {shape, names} From e9fdbb53c46e4923270067f106f6975032e7b6e2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Valim?= Date: Sat, 11 Nov 2023 20:56:52 +0100 Subject: [PATCH 4/6] More fixes --- exla/lib/exla/defn.ex | 13 +++--- nx/lib/nx.ex | 92 +++++++++++++++---------------------- nx/lib/nx/binary_backend.ex | 34 ++------------ 3 files changed, 47 insertions(+), 92 deletions(-) diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex index f656100245..6994f6e9b0 100644 --- a/exla/lib/exla/defn.ex +++ b/exla/lib/exla/defn.ex @@ -1517,22 +1517,21 @@ defmodule EXLA.Defn do Value.gather(tensor, indices, slice_sizes, offset_dims, axes, axes, index_vector_dim) end - defp to_operator(:gather, [tensor, indices, opts], ans, _state) do + defp to_operator(:gather, [tensor, indices, opts], _ans, _state) do axes = Keyword.fetch!(opts, :axes) tensor_shape = op_shape(tensor) tensor_rank = tuple_size(tensor_shape) + tensor_axes = axes_for_rank(tensor_rank) index_vector_dim = tuple_size(op_shape(indices)) - 1 slice_sizes = - for i <- 0..(tensor_rank - 1) do + for i <- tensor_axes do if i in axes, do: 1, else: elem(tensor_shape, i) end - offset_dims = axes_for_rank(tensor_rank) -- axes - - tensor - |> EXLA.Op.gather(indices, index_vector_dim, slice_sizes, offset_dims, axes, axes) - |> EXLA.Op.reshape(ans.shape) + batch_size = tensor_rank - length(axes) + offset_dims = count_up(batch_size, batch_size) + EXLA.Op.gather(tensor, indices, index_vector_dim, slice_sizes, offset_dims, axes, axes) end defp to_operator(:reverse, [%Value{} = tensor, axes], _ans, _state) do diff --git a/nx/lib/nx.ex b/nx/lib/nx.ex index fc84a0c539..428ec7e5f4 100644 --- a/nx/lib/nx.ex +++ b/nx/lib/nx.ex @@ -12663,48 +12663,6 @@ defmodule Nx do ] > - iex> Nx.transpose(Nx.iota({2, 3, 4}, names: [:batch, :x, :y]), axes: [:y, :batch, :x]) - #Nx.Tensor< - s64[y: 4][batch: 
2][x: 3] - [ - [ - [0, 4, 8], - [12, 16, 20] - ], - [ - [1, 5, 9], - [13, 17, 21] - ], - [ - [2, 6, 10], - [14, 18, 22] - ], - [ - [3, 7, 11], - [15, 19, 23] - ] - ] - > - - iex> Nx.transpose(Nx.iota({2, 3, 4}, names: [:batch, :x, :y]), axes: [:batch, :y, :x]) - #Nx.Tensor< - s64[batch: 2][y: 4][x: 3] - [ - [ - [0, 4, 8], - [1, 5, 9], - [2, 6, 10], - [3, 7, 11] - ], - [ - [12, 16, 20], - [13, 17, 21], - [14, 18, 22], - [15, 19, 23] - ] - ] - > - ### Vectorized tensors For vectorized tensors, transpose will manipulate the inner shape only, @@ -14331,31 +14289,53 @@ defmodule Nx do ### Gathering subsets iex> t = Nx.tensor([[1, 2, 3], [3, 4, 5]]) - iex> Nx.gather(t, Nx.tensor([[1], [0], [1]])) + iex> Nx.gather(t, Nx.tensor([[1], [0]])) #Nx.Tensor< - s64[3][3] + s64[2][3] [ [3, 4, 5], - [1, 2, 3], - [3, 4, 5] + [1, 2, 3] ] > The `:axes` option controls which dimensions the indexes point to, - this can be useful, for example, to access columns instead of rows: + this can be useful, for example, to access columns instead of rows. + Note can also access the same index several times: - iex> t = Nx.tensor([[[1, 2, 3]], [[4, 5, 6]]]) - iex> Nx.gather(t, Nx.tensor([[1], [0]]), axes: [2]) + iex> t = Nx.tensor([[1, 2, 3], [4, 5, 6]]) + iex> Nx.gather(t, Nx.tensor([[1], [0], [2], [1]]), axes: [1]) #Nx.Tensor< - s64[2][2][1] + s64[4][2] + [ + [2, 5], + [1, 4], + [3, 6], + [2, 5] + ] + > + + The overall output shape will have the format of the indices shape + (except the last element) followed by all non-indexed dimensions of + the tensor. Here is a more complex example: + + iex> t = Nx.iota({2, 1, 3}) + iex> Nx.gather(t, Nx.tensor([[[1], [0], [2]]]), axes: [2]) + #Nx.Tensor< + s64[1][3][2][1] [ [ - [2], - [1] - ], - [ - [5], - [4] + [ + [1], + [4] + ], + [ + [0], + [3] + ], + [ + [2], + [5] + ] ] ] > diff --git a/nx/lib/nx/binary_backend.ex b/nx/lib/nx/binary_backend.ex index 38987bf270..a826760db7 100644 --- a/nx/lib/nx/binary_backend.ex +++ b/nx/lib/nx/binary_backend.ex @@ -2091,38 +2091,14 @@ defmodule Nx.BinaryBackend do axes = opts[:axes] tensor_axes = Nx.axes(tensor) - {permutation_fn, inverse_permutation_fn} = - if axes && axes != tensor_axes do - permutation = axes ++ (tensor_axes -- axes) - - inverse_permutation = - permutation - |> Enum.filter(&(&1 in Nx.axes(out))) - |> Enum.with_index() - |> Enum.sort_by(fn {x, _} -> x end) - |> Enum.map(fn {_, i} -> i end) - - diff = Nx.rank(out) - length(inverse_permutation) - - inverse_permutation = - if diff > 0 do - Enum.to_list(0..(diff - 1)) ++ Enum.map(inverse_permutation, &(&1 + diff)) - else - inverse_permutation - end - - { - &Nx.transpose(&1, axes: permutation), - &Nx.transpose(&1, axes: inverse_permutation) - } + tensor = + if List.starts_with?(tensor_axes, axes) do + tensor else - {& &1, & &1} + Nx.transpose(tensor, axes: axes ++ (tensor_axes -- axes)) end - out - |> gather(permutation_fn.(tensor), indices) - |> then(inverse_permutation_fn) - |> Nx.reshape(out.shape) + gather(out, tensor, indices) end defp gather(out, tensor, indices) do From 0e485f9e95f5b6dad379f36160157170d90b1080 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Valim?= Date: Sat, 11 Nov 2023 21:05:54 +0100 Subject: [PATCH 5/6] Torch --- torchx/lib/torchx/backend.ex | 23 ++++------------------- 1 file changed, 4 insertions(+), 19 deletions(-) diff --git a/torchx/lib/torchx/backend.ex b/torchx/lib/torchx/backend.ex index 0d06086f30..e12747de03 100644 --- a/torchx/lib/torchx/backend.ex +++ b/torchx/lib/torchx/backend.ex @@ -435,35 +435,20 @@ defmodule Torchx.Backend do @impl 
true def gather(out, tensor, indices, opts) do tensor_axes = Nx.axes(tensor) - axes = opts[:axes] - {permutation_fn, inverse_permutation_fn} = - if axes && axes != tensor_axes do - permutation = axes ++ (tensor_axes -- axes) - - inverse_permutation = - permutation - |> Enum.with_index() - |> Enum.sort_by(fn {x, _} -> x end) - |> Enum.map(fn {_, i} -> i end) - - { - &Torchx.permute(&1, permutation), - &Torchx.permute(&1, inverse_permutation) - } + tensor = + if is_nil(axes) or List.starts_with?(tensor_axes, axes) do + tensor else - {& &1, & &1} + Nx.transpose(tensor, axes: axes ++ (tensor_axes -- axes)) end {tensor_tx, indices_tx} = indices_from_nx(tensor, indices) tensor_tx - |> then(permutation_fn) |> Torchx.index(indices_tx) |> Torchx.reshape(out.shape) - |> then(inverse_permutation_fn) - |> Torchx.reshape(out.shape) |> to_nx(out) end From 2234e3441c121637044e509f309dee8df4fb83bd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Valim?= Date: Sat, 11 Nov 2023 21:24:36 +0100 Subject: [PATCH 6/6] fixes --- nx/lib/nx/binary_backend.ex | 219 +++++++++++++++++------------------- 1 file changed, 102 insertions(+), 117 deletions(-) diff --git a/nx/lib/nx/binary_backend.ex b/nx/lib/nx/binary_backend.ex index a826760db7..89436c528f 100644 --- a/nx/lib/nx/binary_backend.ex +++ b/nx/lib/nx/binary_backend.ex @@ -1761,129 +1761,92 @@ defmodule Nx.BinaryBackend do end defp indexed_op(out, target, indices, updates, opts, resolve_updates, update_element) do - axes = opts[:axes] - target_axes = Nx.axes(target) - - {permutation_fn, inverse_permutation_fn, out_shape} = - if axes && axes != target_axes do - permutation = axes ++ (target_axes -- axes) - - inverse_permutation = - permutation - |> Enum.with_index() - |> Enum.sort_by(fn {x, _} -> x end) - |> Enum.map(fn {_, i} -> i end) - - transposed_out_shape = Enum.map(permutation, &elem(out.shape, &1)) |> List.to_tuple() - - { - &Nx.transpose(&1, axes: permutation), - &Nx.transpose(&1, axes: inverse_permutation), - transposed_out_shape - } - else - {& &1, & &1, out.shape} - end - - %{out | shape: out_shape} - |> indexed_op_on_first_dims( - permutation_fn.(target), - indices, - updates, - resolve_updates, - update_element - ) - |> then(inverse_permutation_fn) - end - - defp indexed_op_on_first_dims( - %T{} = out, - %T{shape: shape, type: {_, target_size}} = target, - %T{shape: indices_shape} = indices, - %T{shape: updates_shape, type: {_, updates_size}} = updates, - resolve_updates, - update_element - ) do - indices_bin_list = - indices - |> to_binary() - |> aggregate_axes([1], indices_shape, elem(indices.type, 1)) - - updates_binary = to_binary(updates) - updates_count = updates_shape |> Tuple.product() |> div(elem(updates_shape, 0)) - updates_chunk = updates_count * updates_size - target_chunk = updates_count * target_size - - updates_list = - for <> do - binary_to_list(x, updates.type) - end + with_permutation(target, out, opts, fn target, out -> + %T{shape: shape, type: {_, target_size}} = target + %T{shape: indices_shape} = indices + %T{shape: updates_shape, type: {_, updates_size}} = updates + + indices_bin_list = + indices + |> to_binary() + |> aggregate_axes([1], indices_shape, elem(indices.type, 1)) + + updates_binary = to_binary(updates) + updates_count = updates_shape |> Tuple.product() |> div(elem(updates_shape, 0)) + updates_chunk = updates_count * updates_size + target_chunk = updates_count * target_size + + updates_list = + for <> do + binary_to_list(x, updates.type) + end - offsets_list = - for idx_bin <- indices_bin_list do - idx = 
binary_to_list(idx_bin, indices.type) - offset = index_to_binary_offset(idx, shape) - offset * target_size - end + offsets_list = + for idx_bin <- indices_bin_list do + idx = binary_to_list(idx_bin, indices.type) + offset = index_to_binary_offset(idx, shape) + offset * target_size + end - {offsets_with_updates, _last_offset} = - offsets_list - |> Enum.zip(updates_list) - |> Enum.group_by(fn {off, _} -> off end, fn {_, upd} -> upd end) - |> Enum.sort_by(fn {off, _} -> off end) - |> Enum.map_reduce(0, fn {next_offset, upds}, previous_offset -> - {{ - previous_offset + target_chunk, - next_offset, - resolve_updates.(upds) - }, next_offset} - end) + {offsets_with_updates, _last_offset} = + offsets_list + |> Enum.zip(updates_list) + |> Enum.group_by(fn {off, _} -> off end, fn {_, upd} -> upd end) + |> Enum.sort_by(fn {off, _} -> off end) + |> Enum.map_reduce(0, fn {next_offset, upds}, previous_offset -> + {{ + previous_offset + target_chunk, + next_offset, + resolve_updates.(upds) + }, next_offset} + end) - target_binary = to_binary(target) + target_binary = to_binary(target) - offsets_with_updates = - List.update_at(offsets_with_updates, 0, fn {_prev, current, update} -> - {0, current, update} - end) + offsets_with_updates = + List.update_at(offsets_with_updates, 0, fn {_prev, current, update} -> + {0, current, update} + end) - {result, tail} = - for {previous, current, updates} <- offsets_with_updates, reduce: {<<>>, target_binary} do - {traversed, to_traverse} -> - before_slice_size = current - previous + {result, tail} = + for {previous, current, updates} <- offsets_with_updates, reduce: {<<>>, target_binary} do + {traversed, to_traverse} -> + before_slice_size = current - previous - <> = to_traverse + <> = to_traverse - updated_elements = - match_types [target.type, out.type] do - current = for <>, do: read!(element, 0) + updated_elements = + match_types [target.type, out.type] do + current = for <>, do: read!(element, 0) - Enum.zip_with(current, updates, fn left, right -> - <> - end) - end + Enum.zip_with(current, updates, fn left, right -> + <> + end) + end - # this can be a list of binaries because we are accumulation an iodata list - before_offset = - if target.type == out.type do - before_offset - else - binary_to_binary(before_offset, target.type, out.type, & &1) - end + # this can be a list of binaries because we are accumulation an iodata list + before_offset = + if target.type == out.type do + before_offset + else + binary_to_binary(before_offset, target.type, out.type, & &1) + end - {[traversed, before_offset | updated_elements], to_traverse} - end + {[traversed, before_offset | updated_elements], to_traverse} + end - # this can be a list of binaries because we are accumulation an iodata list - tail = - if target.type == out.type do - tail - else - binary_to_binary(tail, target.type, out.type, & &1) - end + # this can be a list of binaries because we are accumulation an iodata list + tail = + if target.type == out.type do + tail + else + binary_to_binary(tail, target.type, out.type, & &1) + end - from_binary(out, IO.iodata_to_binary([result, tail])) + from_binary(out, IO.iodata_to_binary([result, tail])) + end) end @impl true @@ -2057,12 +2020,7 @@ defmodule Nx.BinaryBackend do |> List.delete(axis) |> List.insert_at(Nx.rank(tensor) - 1, axis) - inverse_permutation = - permutation - |> Enum.with_index() - |> Enum.sort_by(fn {x, _} -> x end) - |> Enum.map(fn {_, i} -> i end) - + inverse_permutation = inverse_permutation(permutation) shape_list = Tuple.to_list(output.shape) 
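    # `inverse_permutation/1` is the helper added in the "Permutation helpers"
    # section later in this diff. It inverts an axis permutation: with axis: 1
    # on a rank-4 tensor the permutation built above is [0, 2, 3, 1], whose
    # inverse is [0, 3, 1, 2], so transposing by one and then by the other
    # restores the original axis order.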
permuted_shape = permutation |> Enum.map(&Enum.at(shape_list, &1)) |> List.to_tuple() @@ -2597,6 +2555,33 @@ defmodule Nx.BinaryBackend do end end + ## Permutation helpers + + defp with_permutation(target, out, opts, fun) do + axes = opts[:axes] + target_axes = Nx.axes(target) + + if is_nil(axes) or List.starts_with?(target_axes, axes) do + fun.(target, out) + else + permutation = axes ++ (target_axes -- axes) + inverse_permutation = inverse_permutation(permutation) + out_shape = Enum.map(permutation, &elem(out.shape, &1)) |> List.to_tuple() + + target + |> Nx.transpose(axes: permutation) + |> fun.(%{out | shape: out_shape}) + |> Nx.transpose(axes: inverse_permutation) + end + end + + defp inverse_permutation(permutation) do + permutation + |> Enum.with_index() + |> Enum.sort_by(fn {x, _} -> x end) + |> Enum.map(fn {_, i} -> i end) + end + ## Aggregation helpers defp aggregate_axes(binary, axes, shape, size) do
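
A minimal usage sketch of the behavior this series targets — taking the gradient through Nx.gather/3 when :axes points at a non-leading axis. It mirrors the defn/grad pattern used in grad_test.exs above; the module and function names are illustrative only and are not part of the patch.

defmodule GatherAxesGradExample do
  import Nx.Defn

  # Sum of a column-wise gather: with axes: [1] each index row selects a
  # column, so the gathered shape is {num_indices, rows}.
  defn grad_gather_columns(t, i) do
    grad(t, fn t -> t |> Nx.gather(i, axes: [1]) |> Nx.sum() end)
  end
end

t = Nx.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
i = Nx.tensor([[1], [0]])

Nx.gather(t, i, axes: [1])
#=> columns 1 and 0 of t: [[2.0, 5.0], [1.0, 4.0]]

GatherAxesGradExample.grad_gather_columns(t, i)
#=> each gathered element contributes 1.0 to its source position,
#=> so columns 0 and 1 receive ones: [[1.0, 1.0, 0.0], [1.0, 1.0, 0.0]]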