Skip to content

Commit

Permalink
fix(grad): Nx.stack grad should remove the added axis (unbroadcast) (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
polvalente authored Sep 24, 2024
1 parent 93e4383 commit 8102cd9
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 1 deletion.
5 changes: 4 additions & 1 deletion nx/lib/nx/defn/grad.ex
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ defmodule Nx.Defn.Grad do
expr = to_grad |> fun.()

transformed_expr = transform.(expr) |> validate_expr!() |> Nx.devectorize(keep_names: false)

{parents, nodes} = parents_tree(transformed_expr, ids)

to_grad_ids = {to_grad, ids}
Expand Down Expand Up @@ -623,7 +624,9 @@ defmodule Nx.Defn.Grad do
current_limit = 1 + limit
start = List.replace_at(zero_axes, axis, limit)
len = List.replace_at(ans_shape_list, axis, 1)
{{t, Nx.slice(g, start, len)}, current_limit}
g = Nx.slice(g, start, len)
g = Nx.squeeze(g, axes: [axis])
{{t, g}, current_limit}
end)

pairs
Expand Down
18 changes: 18 additions & 0 deletions nx/test/nx/defn/grad_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -1773,6 +1773,24 @@ defmodule Nx.Defn.GradTest do
end
end

describe "stack" do
test "works on compound functions for more than 1 axis" do
# This is a test that ensures that the added axis from the
# stack operation is correctly squeezed back out by
# the gradient computation.
x = 2.0

assert grad(Nx.tensor([[x]]), fn t ->
a = Nx.pow(t, 2)
b = Nx.pow(t, 3)
c = Nx.pow(t, 4)

Nx.stack([a, b, c], axis: 1)
|> Nx.sum()
end) == Nx.tensor([[2 * x + 3 * x ** 2 + 4 * x ** 3]])
end
end

describe "cholesky" do
defn cholesky_grad(t) do
grad(t, fn x -> x |> Nx.LinAlg.cholesky() |> Nx.sum() end)
Expand Down

0 comments on commit 8102cd9

Please sign in to comment.