Skip to content

Commit

Permalink
Fix: force hessenberg range to be positive (#1498)
Browse files Browse the repository at this point in the history
  • Loading branch information
polvalente authored Jun 9, 2024
1 parent 970a6d8 commit a91a97c
Showing 1 changed file with 13 additions and 3 deletions.
16 changes: 13 additions & 3 deletions nx/lib/nx/lin_alg/eigh.ex
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,17 @@ defmodule Nx.LinAlg.Eigh do
}
end

defn eigh_matrix(a, opts \\ []) do
defnp eigh_matrix(a, opts \\ []) do
case Nx.shape(a) do
{1, 1} ->
{a, Nx.fill(a, 1)}

{_, _} ->
eigh_2d(a, opts)
end
end

defnp eigh_2d(a, opts \\ []) do
# The input Hermitian matrix A reduced to Hessenberg matrix H by Householder transform.
# Then, by using QR iteration it converges to AQ = QΛ,
# where Λ is the diagonal matrix of eigenvalues and the columns of Q are the eigenvectors.
Expand Down Expand Up @@ -56,7 +66,7 @@ defmodule Nx.LinAlg.Eigh do
{eigenvals, eigenvecs}
end

defn hessenberg_decomposition(matrix, eps) do
defnp hessenberg_decomposition(matrix, eps) do
# The input Hermitian matrix A reduced to Hessenberg matrix H by Householder transform.
# Then, by using QR iteration it converges to AQ = QΛ,
# where Λ is the diagonal matrix of eigenvalues and the columns of Q are the eigenvectors.
Expand All @@ -70,7 +80,7 @@ defmodule Nx.LinAlg.Eigh do

{{hess, q}, _} =
while {{hess = Nx.as_type(matrix, out_type), q = eye}, {eps, column_iota}},
i <- 0..(n - 2) do
i <- 0..(n - 2)//1 do
x = hess[[.., i]]
x = Nx.select(column_iota <= i, 0, x)
h = QR.householder_reflector(x, i, eps)
Expand Down

0 comments on commit a91a97c

Please sign in to comment.