Skip to content

Commit

Permalink
fix: stop_grad for logsumexp (#1470)
Browse files Browse the repository at this point in the history
  • Loading branch information
polvalente authored Apr 6, 2024
1 parent bb056ce commit adc6cd0
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 0 deletions.
10 changes: 10 additions & 0 deletions nx/lib/nx.ex
Original file line number Diff line number Diff line change
Expand Up @@ -17332,6 +17332,16 @@ defmodule Nx do
axes = opts[:axes]
keep_axes = opts[:keep_axes]
max = reduce_max(tensor, axes: axes, keep_axes: true)

max =
case max do
%T{data: %Nx.Defn.Expr{}} = t ->
Nx.Defn.Kernel.stop_grad(t)

t ->
t
end

infinity_mask = is_infinity(max)
max = select(infinity_mask, Nx.tensor(0, type: type), max)
exponentials = tensor |> subtract(max) |> exp()
Expand Down
10 changes: 10 additions & 0 deletions nx/test/nx/defn/expr_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,16 @@ defmodule Nx.Defn.ExprTest do
named = Nx.tensor([4], names: [:dim])
assert %T{type: {:f, 32}, names: [:dim]} = Nx.multiply(Expr.tensor(named), Expr.tensor(1.0))
end

test "logsumexp" do
expr = Nx.logsumexp(Expr.tensor(Nx.tensor([1, 2, 3, 4, 5, 6])))

assert inspect(expr) =~ """
tensor a s64[6]
b = reduce_max a, axes: [0], keep_axes: true s64[1]
c = metadata b, :stop_grad s64[1]
"""
end
end

describe "inspect" do
Expand Down

0 comments on commit adc6cd0

Please sign in to comment.