From a91a97ce4eac3a3d47c07aba5f76c08bd02344f3 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Sun, 9 Jun 2024 17:31:30 -0300 Subject: [PATCH] Fix: force hessenberg range to be positive (#1498) --- nx/lib/nx/lin_alg/eigh.ex | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/nx/lib/nx/lin_alg/eigh.ex b/nx/lib/nx/lin_alg/eigh.ex index 388e1dc17d..05673ad2ba 100644 --- a/nx/lib/nx/lin_alg/eigh.ex +++ b/nx/lib/nx/lin_alg/eigh.ex @@ -26,7 +26,17 @@ defmodule Nx.LinAlg.Eigh do } end - defn eigh_matrix(a, opts \\ []) do + defnp eigh_matrix(a, opts \\ []) do + case Nx.shape(a) do + {1, 1} -> + {a, Nx.fill(a, 1)} + + {_, _} -> + eigh_2d(a, opts) + end + end + + defnp eigh_2d(a, opts \\ []) do # The input Hermitian matrix A reduced to Hessenberg matrix H by Householder transform. # Then, by using QR iteration it converges to AQ = QΛ, # where Λ is the diagonal matrix of eigenvalues and the columns of Q are the eigenvectors. @@ -56,7 +66,7 @@ defmodule Nx.LinAlg.Eigh do {eigenvals, eigenvecs} end - defn hessenberg_decomposition(matrix, eps) do + defnp hessenberg_decomposition(matrix, eps) do # The input Hermitian matrix A reduced to Hessenberg matrix H by Householder transform. # Then, by using QR iteration it converges to AQ = QΛ, # where Λ is the diagonal matrix of eigenvalues and the columns of Q are the eigenvectors. @@ -70,7 +80,7 @@ defmodule Nx.LinAlg.Eigh do {{hess, q}, _} = while {{hess = Nx.as_type(matrix, out_type), q = eye}, {eps, column_iota}}, - i <- 0..(n - 2) do + i <- 0..(n - 2)//1 do x = hess[[.., i]] x = Nx.select(column_iota <= i, 0, x) h = QR.householder_reflector(x, i, eps)