From 93abe4854534a3a0bd9c5493dc058be702d53614 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Mon, 28 Oct 2024 22:09:44 -0300 Subject: [PATCH 1/2] fix: least_squares implementation --- nx/lib/nx/lin_alg.ex | 24 +++++++++++------------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/nx/lib/nx/lin_alg.ex b/nx/lib/nx/lin_alg.ex index 0a353f4320..450a48cae4 100644 --- a/nx/lib/nx/lin_alg.ex +++ b/nx/lib/nx/lin_alg.ex @@ -2152,12 +2152,16 @@ defmodule Nx.LinAlg do @doc """ Return the least-squares solution to a linear matrix equation Ax = b. + ## Options + + * `:eps` - Rounding error threshold used to assume values as 0. Defaults to `1.0e-15` + ## Examples iex> Nx.LinAlg.least_squares(Nx.tensor([[1, 2], [2, 3]]), Nx.tensor([1, 2])) #Nx.Tensor< f32[2] - [1.0000004768371582, -2.665601925855299e-7] + [0.9977624416351318, 0.0011188983917236328] > iex> Nx.LinAlg.least_squares(Nx.tensor([[0, 1], [1, 1], [2, 1], [3, 1]]), Nx.tensor([-1, 0.2, 0.9, 2.1])) @@ -2187,7 +2191,9 @@ defmodule Nx.LinAlg do ** (ArgumentError) the number of rows of the matrix as the 1st argument and the number of columns of the vector as the 2nd argument must be the same, got 1st argument shape {2, 2} and 2nd argument shape {3} """ @doc from_backend: false - defn least_squares(a, b) do + defn least_squares(a, b, opts \\ []) do + opts = keyword!(opts, eps: 1.0e-15) + %T{type: a_type, shape: a_shape} = Nx.to_tensor(a) a_size = Nx.rank(a_shape) %T{type: b_type, shape: b_shape} = Nx.to_tensor(b) @@ -2235,17 +2241,9 @@ defmodule Nx.LinAlg do ) end - case a_shape do - {m, n} when m == n -> - Nx.LinAlg.solve(a, b) - - {m, n} when m != n -> - Nx.LinAlg.pinv(a, eps: 1.0e-15) - |> Nx.dot(b) - - _ -> - nil - end + a + |> Nx.LinAlg.pinv(eps: opts[:eps]) + |> Nx.dot(b) end defp apply_vectorized(tensor, fun) when is_function(fun, 1) do From 4bba06fdb5f7fab25146d1d91b17b88f7a26cc8b Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Mon, 28 Oct 2024 22:28:43 -0300 Subject: [PATCH 2/2] fix: lsq arity in test skip --- exla/test/exla/nx_linalg_doctest_test.exs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exla/test/exla/nx_linalg_doctest_test.exs b/exla/test/exla/nx_linalg_doctest_test.exs index 6df3aeec10..10c2cbce05 100644 --- a/exla/test/exla/nx_linalg_doctest_test.exs +++ b/exla/test/exla/nx_linalg_doctest_test.exs @@ -10,7 +10,7 @@ defmodule EXLA.MLIR.NxLinAlgDoctestTest do invert: 1, matrix_power: 2 ] - @rounding_error_doctests [triangular_solve: 3, eigh: 2, cholesky: 1, least_squares: 2] + @rounding_error_doctests [triangular_solve: 3, eigh: 2, cholesky: 1, least_squares: 3] @excluded_doctests @function_clause_error_doctests ++ @rounding_error_doctests ++