Skip to content

Commit

Permalink
fix: Nx.LinAlg.norm axes support
Browse files Browse the repository at this point in the history
  • Loading branch information
polvalente committed Aug 16, 2024
1 parent 7a3d7cd commit e85f262
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 6 deletions.
23 changes: 17 additions & 6 deletions nx/lib/nx/lin_alg.ex
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,10 @@ defmodule Nx.LinAlg do
# The idea is that by dividing the tensor by it, large values of
# tensor entries and large values of p are reduced, which in turn
# avoids numerical overflow.

keep_axes = opts[:keep_axes]

opts = Keyword.put(opts, :keep_axes, true)
numerical_stability_coefficient = Nx.reduce_max(abs_t, opts)

# This code prevents from division by zero.
Expand All @@ -398,12 +402,19 @@ defmodule Nx.LinAlg do
1
)

abs_t
|> Nx.divide(numerical_stability_coefficient)
|> Nx.pow(ord)
|> Nx.sum(opts)
|> Nx.pow(inv_ord)
|> Nx.multiply(numerical_stability_coefficient)
result =
abs_t
|> Nx.divide(numerical_stability_coefficient)
|> Nx.pow(ord)
|> Nx.sum(opts)
|> Nx.pow(inv_ord)
|> Nx.multiply(numerical_stability_coefficient)

if keep_axes do
result
else
Nx.squeeze(result, Keyword.take(opts, [:axes]))
end
end

@doc """
Expand Down
13 changes: 13 additions & 0 deletions nx/test/nx/lin_alg_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,19 @@ defmodule Nx.LinAlgTest do
Nx.LinAlg.norm(t, ord: -3)
end)
end

test "correctly support axes option" do
t =
Nx.tensor([
[-1.0, -1.0],
[0.0, 0.0],
[1.0, 1.0]
])

result = Nx.tensor([1.4142135381698608, 0.0, 1.4142135381698608])
assert Nx.LinAlg.norm(t, axes: [1]) == result
assert Nx.LinAlg.norm(t, axes: [1], keep_axes: true) == Nx.reshape(result, {3, 1})
end
end

describe "matrix_power" do
Expand Down

0 comments on commit e85f262

Please sign in to comment.