fix: grad for indexed ops with axes #1360

Merged · 6 commits · Nov 11, 2023
Changes from 2 commits
10 changes: 10 additions & 0 deletions nx/lib/nx/binary_backend.ex
@@ -2097,10 +2097,20 @@ defmodule Nx.BinaryBackend do

inverse_permutation =
permutation
|> Enum.filter(&(&1 in Nx.axes(out)))
|> Enum.with_index()
|> Enum.sort_by(fn {x, _} -> x end)
|> Enum.map(fn {_, i} -> i end)

diff = Nx.rank(out) - length(inverse_permutation)

inverse_permutation =
if diff > 0 do
Enum.to_list(0..(diff - 1)) ++ Enum.map(inverse_permutation, &(&1 + diff))
else
inverse_permutation
end
Contributor Author:
I'm not totally sure this is the correct fix. Basically, the test for grad(gather) was failing and this change fixed it.

Collaborator:
Will we ever hit the else branch?

Contributor Author:
It seems we can. Keeping only the do branch works if we change the range to 0..(diff - 1)//1, because in some cases diff is 0. I think those are the cases where the index tensor has as many axes as the input tensor.

The Enum.filter(&(&1 in Nx.axes(out))) call takes care of the case where the index tensor has fewer axes, and the do block takes care of the case where it has more axes than the input.
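For illustration, here is a standalone sketch of the padding logic above, run on hypothetical example values (a rank-3 permutation against a rank-4 output); the permutation [2, 0, 1] and the axis list are made up for the example and are not taken from the PR:

# Hypothetical values: the forward transpose used permutation [2, 0, 1],
# while the output tensor has rank 4, i.e. one extra leading axis.
permutation = [2, 0, 1]
out_axes = [0, 1, 2, 3]

inverse_permutation =
  permutation
  |> Enum.filter(&(&1 in out_axes))
  |> Enum.with_index()
  |> Enum.sort_by(fn {x, _} -> x end)
  |> Enum.map(fn {_, i} -> i end)

# inverse_permutation is [1, 2, 0] and diff is 4 - 3 = 1, so the do branch
# prepends the extra leading axis and shifts the remaining entries by one.
diff = length(out_axes) - length(inverse_permutation)

inverse_permutation =
  if diff > 0 do
    Enum.to_list(0..(diff - 1)) ++ Enum.map(inverse_permutation, &(&1 + diff))
  else
    inverse_permutation
  end

IO.inspect(inverse_permutation)
# => [0, 2, 3, 1]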


{
&Nx.transpose(&1, axes: permutation),
&Nx.transpose(&1, axes: inverse_permutation)
14 changes: 10 additions & 4 deletions nx/lib/nx/defn/grad.ex
@@ -806,11 +806,17 @@ defmodule Nx.Defn.Grad do
end

defp grad(:gather, [t, i, opts], _ans, g) do
rank = Nx.rank(t)
num_elements = i.shape |> Tuple.product() |> div(rank)
leading_i_shape = Tuple.delete_at(i.shape, tuple_size(i.shape) - 1)
num_elements = Tuple.product(leading_i_shape)

indices = Nx.reshape(i, {num_elements, rank})
updates = Nx.reshape(g, {num_elements})
indices = Nx.reshape(i, {num_elements, :auto})

updates_shape =
Enum.reduce(Enum.sort(opts[:axes], :desc), t.shape, fn axis, shape ->
Tuple.delete_at(shape, axis)
end)

updates = Nx.reshape(g, Tuple.insert_at(updates_shape, 0, num_elements))

g = t |> Expr.broadcast(0, t.shape, Nx.axes(t)) |> Nx.indexed_add(indices, updates, opts)
[{t, g}]
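As a rough walk-through of the shape bookkeeping in the new grad(:gather) clause, using the shapes from the new_axis test case further below ({3, 1, 4} input, {1, 4, 3, 2} indices, axes: [0, 2]); this is an illustrative sketch, not code from the PR:

# t has shape {3, 1, 4} and is gathered with axes: [0, 2], so the trailing
# dimension of the index tensor i (shape {1, 4, 3, 2}) addresses axes 0 and 2.
t_shape = {3, 1, 4}
i_shape = {1, 4, 3, 2}
axes = [0, 2]

leading_i_shape = Tuple.delete_at(i_shape, tuple_size(i_shape) - 1)
num_elements = Tuple.product(leading_i_shape)
# => 12 gathered slices in total

updates_shape =
  Enum.reduce(Enum.sort(axes, :desc), t_shape, fn axis, shape ->
    Tuple.delete_at(shape, axis)
  end)
# => {1}: the gathered axes are dropped, the untouched axis 1 remains

# Each incoming gradient slice therefore has shape {1}, and Nx.indexed_add
# receives indices of shape {12, 2} and updates of shape {12, 1}.
Tuple.insert_at(updates_shape, 0, num_elements)
# => {12, 1}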
134 changes: 96 additions & 38 deletions nx/test/nx/defn/grad_test.exs
@@ -3355,13 +3355,13 @@ defmodule Nx.Defn.GradTest do
)
end

defn grad_sum_log_power_gather_cos(t, i) do
defn grad_sum_log_power_gather_cos(t, i, opts \\ []) do
grad(
t,
fn t ->
t
|> Nx.cos()
|> Nx.gather(i)
|> Nx.gather(i, opts)
|> Nx.pow(2)
|> Nx.log()
|> Nx.sum()
@@ -3427,26 +3427,34 @@ defmodule Nx.Defn.GradTest do
])
)

assert Nx.tensor([
[-0.0, -9.34444522857666, 4.370079040527344, -0.0],
[-2.3156425952911377, 6.7610297203063965, 0.0, -0.0],
[13.5994234085083, -0.0, -1.2967215776443481, 1355.705078125]
]) ==
grad_sum_log_power_gather_cos(
Nx.tensor([
[0, 1, 2, 3],
[4, 5, 6, 7],
[8, 9, 10, 11]
]),
Nx.tensor([
[
[[0, 0], [0, 1], [0, 2]],
[[2, 0], [1, 0], [0, 1]],
[[0, 1], [1, 1], [2, 2]],
[[2, 3], [2, 3], [2, 3]]
]
])
)
t =
Nx.tensor([
[0, 1, 2, 3],
[4, 5, 6, 7],
[8, 9, 10, 11]
])

i =
Nx.tensor([
[
[[0, 0], [0, 1], [0, 2]],
[[2, 0], [1, 0], [0, 1]],
[[0, 1], [1, 1], [2, 2]],
[[2, 3], [2, 3], [2, 3]]
]
])

result =
Nx.tensor([
[-0.0, -9.34444522857666, 4.370079040527344, -0.0],
[-2.3156425952911377, 6.7610297203063965, 0.0, -0.0],
[13.5994234085083, -0.0, -1.2967215776443481, 1355.705078125]
])

assert result == grad_sum_log_power_gather_cos(t, i)

assert Nx.new_axis(result, 1) ==
grad_sum_log_power_gather_cos(Nx.new_axis(t, 1), i, axes: [0, 2])
end
end
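For reference, a minimal forward-pass sketch of the gather exercised by the new assertion above. It assumes, as the updates_shape computation in grad.ex implies, that the :axes option names the input axes addressed by the trailing dimension of the index tensor while the remaining axes are kept whole; the concrete index values here are made up:

t = Nx.iota({3, 4}) |> Nx.new_axis(1)   # shape {3, 1, 4}
i = Nx.tensor([[[0, 0], [2, 3]]])       # shape {1, 2, 2}

# Each pair in the last dimension of i picks a position along axes 0 and 2
# of t; axis 1 is taken whole, so every gathered slice has shape {1} and
# the overall result has shape {1, 2, 1}.
Nx.gather(t, i, axes: [0, 2])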

@@ -3555,42 +3563,42 @@ defmodule Nx.Defn.GradTest do
describe "indexed_put" do
defn grad_indexed_put_target(t, i, u), do: grad(t, &Nx.sum(Nx.indexed_put(&1, i, u)))

defn grad_indexed_put_target_composite(t, i, u) do
defn grad_indexed_put_target_composite(t, i, u, opts \\ []) do
grad(t, fn t ->
t
|> Nx.cos()
|> Nx.indexed_put(i, u)
|> Nx.indexed_put(i, u, opts)
|> Nx.sin()
|> Nx.sum()
end)
end

defn grad_indexed_put_updates(t, i, u), do: grad(u, &Nx.sum(Nx.indexed_put(t, i, &1)))

defn grad_indexed_put_updates_composite(t, i, u) do
defn grad_indexed_put_updates_composite(t, i, u, opts \\ []) do
grad(u, fn u ->
t
|> Nx.indexed_put(i, Nx.cos(u))
|> Nx.indexed_put(i, Nx.cos(u), opts)
|> Nx.sin()
|> Nx.sum()
end)
end

defn grad_indexed_put_indices(t, i, u), do: grad(i, &Nx.sum(Nx.indexed_put(t, &1, u)))

defn grad_indexed_put_indices_composite(t, i, u) do
defn grad_indexed_put_indices_composite(t, i, u, opts \\ []) do
grad(i, fn i ->
t
|> Nx.indexed_put(Nx.multiply(i, 2), u)
|> Nx.indexed_put(Nx.multiply(i, 2), u, opts)
|> Nx.sin()
|> Nx.sum()
end)
end

defn grad_indexed_put_simultaneous_composite(t, i) do
defn grad_indexed_put_simultaneous_composite(t, i, opts \\ []) do
grad(t, fn t ->
t
|> Nx.indexed_put(i, Nx.cos(t))
|> Nx.indexed_put(i, Nx.cos(t), opts)
|> Nx.sin()
|> Nx.sum()
end)
@@ -3618,6 +3626,16 @@ defmodule Nx.Defn.GradTest do
]),
grad_indexed_put_target_composite(t, i, u)
)

assert_all_close(
Nx.tensor([
[0, 0, -0.8316, -0.0774],
[0, 0.9206, 0.1602, -0.4789],
[-0.9789, -0.2525, 0, 0]
])
|> Nx.new_axis(1),
grad_indexed_put_target_composite(Nx.new_axis(t, 1), i, Nx.new_axis(u, 1), axes: [0, 2])
)
end

test "grad wrt to source" do
@@ -3631,6 +3649,11 @@
# f'(x) = cos(cos(x)) * (-sin(x))
expected = u |> Nx.cos() |> Nx.cos() |> Nx.multiply(Nx.sin(u)) |> Nx.negate()
assert_all_close(expected, grad_indexed_put_updates_composite(t, i, u))

assert_all_close(
Nx.new_axis(expected, 1),
grad_indexed_put_updates_composite(Nx.new_axis(t, 1), i, Nx.new_axis(u, 1), axes: [0, 2])
)
end

test "grad wrt to indices" do
@@ -3640,6 +3663,11 @@

assert_all_close(Nx.broadcast(0, i), grad_indexed_put_indices(t, i, u))
assert_all_close(Nx.broadcast(0, i), grad_indexed_put_indices_composite(t, i, u))

assert_all_close(
Nx.broadcast(0, i),
grad_indexed_put_indices_composite(Nx.new_axis(t, 1), i, Nx.new_axis(u, 1), axes: [0, 2])
)
end

test "grad wrt to both source and target simultaneously" do
@@ -3653,6 +3681,11 @@
expected = t |> Nx.cos() |> Nx.cos() |> Nx.multiply(Nx.sin(t)) |> Nx.negate()

assert_all_close(expected, grad_indexed_put_simultaneous_composite(t, i))

assert_all_close(
Nx.new_axis(expected, 1),
grad_indexed_put_simultaneous_composite(Nx.new_axis(t, 1), i, axes: [0])
)
end
end

@@ -3666,11 +3699,11 @@
end)
end

defn grad_indexed_add_target_composite(t, i, u) do
defn grad_indexed_add_target_composite(t, i, u, opts \\ []) do
grad(t, fn t ->
t
|> Nx.cos()
|> Nx.indexed_add(i, u)
|> Nx.indexed_add(i, u, opts)
|> Nx.sin()
|> Nx.sum()
end)
@@ -3685,30 +3718,30 @@
end)
end

defn grad_indexed_add_updates_composite(t, i, u) do
defn grad_indexed_add_updates_composite(t, i, u, opts \\ []) do
grad(u, fn u ->
t
|> Nx.indexed_add(i, Nx.cos(u))
|> Nx.indexed_add(i, Nx.cos(u), opts)
|> Nx.sin()
|> Nx.sum()
end)
end

defn grad_indexed_add_indices(t, i, u), do: grad(i, &Nx.sum(Nx.indexed_add(t, &1, u)))

defn grad_indexed_add_indices_composite(t, i, u) do
defn grad_indexed_add_indices_composite(t, i, u, opts \\ []) do
grad(i, fn i ->
t
|> Nx.indexed_add(Nx.multiply(i, 2), u)
|> Nx.indexed_add(Nx.multiply(i, 2), u, opts)
|> Nx.sin()
|> Nx.sum()
end)
end

defn grad_indexed_add_simultaneous_composite(t, i) do
defn grad_indexed_add_simultaneous_composite(t, i, opts \\ []) do
grad(t, fn t ->
t
|> Nx.indexed_add(i, Nx.cos(t))
|> Nx.indexed_add(i, Nx.cos(t), opts)
|> Nx.sin()
|> Nx.sum()
end)
@@ -3739,6 +3772,16 @@ defmodule Nx.Defn.GradTest do
]),
grad_indexed_add_target_composite(t, i, u)
)

assert_all_close(
Nx.tensor([
[0, -0.0932, -0.8316, -0.0774],
[0.1684, 0.9206, 0.1602, -0.4789],
[-0.9789, -0.2525, -0.1442, -0.9905]
])
|> Nx.new_axis(1),
grad_indexed_add_target_composite(Nx.new_axis(t, 1), i, Nx.new_axis(u, 1), axes: [0, 2])
)
end

test "grad wrt to source" do
@@ -3756,6 +3799,11 @@
expected = cosx_tn |> Nx.cos() |> Nx.multiply(Nx.sin(u)) |> Nx.negate()

assert_all_close(expected, grad_indexed_add_updates_composite(t, i, u))

assert_all_close(
Nx.new_axis(expected, 1),
grad_indexed_add_updates_composite(Nx.new_axis(t, 1), i, Nx.new_axis(u, 1), axes: [0, 2])
)
end

test "grad wrt to indices" do
@@ -3765,6 +3813,11 @@

assert_all_close(Nx.broadcast(0, i), grad_indexed_add_indices(t, i, u))
assert_all_close(Nx.broadcast(0, i), grad_indexed_add_indices_composite(t, i, u))

assert_all_close(
Nx.broadcast(0, i),
grad_indexed_add_indices_composite(Nx.new_axis(t, 1), i, Nx.new_axis(u, 1), axes: [0, 2])
)
end

test "grad wrt to both source and target simultaneously" do
@@ -3779,6 +3832,11 @@
expected = cosx_tn |> Nx.cos() |> Nx.multiply(Nx.subtract(1, Nx.sin(t)))

assert_all_close(expected, grad_indexed_add_simultaneous_composite(t, i))

assert_all_close(
Nx.new_axis(expected, 1),
grad_indexed_add_simultaneous_composite(Nx.new_axis(t, 1), i, axes: [0])
)
end
end
