From 81a7bb7a1f396725fe2bef0d85d5ec634b634f13 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Tue, 29 Oct 2024 14:26:02 -0300 Subject: [PATCH] fix: least_squares implementation (#1550) --- exla/test/exla/nx_linalg_doctest_test.exs | 2 +- nx/lib/nx/lin_alg.ex | 24 +++++++++++------------ 2 files changed, 12 insertions(+), 14 deletions(-) diff --git a/exla/test/exla/nx_linalg_doctest_test.exs b/exla/test/exla/nx_linalg_doctest_test.exs index 6df3aeec102..10c2cbce059 100644 --- a/exla/test/exla/nx_linalg_doctest_test.exs +++ b/exla/test/exla/nx_linalg_doctest_test.exs @@ -10,7 +10,7 @@ defmodule EXLA.MLIR.NxLinAlgDoctestTest do invert: 1, matrix_power: 2 ] - @rounding_error_doctests [triangular_solve: 3, eigh: 2, cholesky: 1, least_squares: 2] + @rounding_error_doctests [triangular_solve: 3, eigh: 2, cholesky: 1, least_squares: 3] @excluded_doctests @function_clause_error_doctests ++ @rounding_error_doctests ++ diff --git a/nx/lib/nx/lin_alg.ex b/nx/lib/nx/lin_alg.ex index 0a353f43205..450a48cae42 100644 --- a/nx/lib/nx/lin_alg.ex +++ b/nx/lib/nx/lin_alg.ex @@ -2152,12 +2152,16 @@ defmodule Nx.LinAlg do @doc """ Return the least-squares solution to a linear matrix equation Ax = b. + ## Options + + * `:eps` - Rounding error threshold used to assume values as 0. Defaults to `1.0e-15` + ## Examples iex> Nx.LinAlg.least_squares(Nx.tensor([[1, 2], [2, 3]]), Nx.tensor([1, 2])) #Nx.Tensor< f32[2] - [1.0000004768371582, -2.665601925855299e-7] + [0.9977624416351318, 0.0011188983917236328] > iex> Nx.LinAlg.least_squares(Nx.tensor([[0, 1], [1, 1], [2, 1], [3, 1]]), Nx.tensor([-1, 0.2, 0.9, 2.1])) @@ -2187,7 +2191,9 @@ defmodule Nx.LinAlg do ** (ArgumentError) the number of rows of the matrix as the 1st argument and the number of columns of the vector as the 2nd argument must be the same, got 1st argument shape {2, 2} and 2nd argument shape {3} """ @doc from_backend: false - defn least_squares(a, b) do + defn least_squares(a, b, opts \\ []) do + opts = keyword!(opts, eps: 1.0e-15) + %T{type: a_type, shape: a_shape} = Nx.to_tensor(a) a_size = Nx.rank(a_shape) %T{type: b_type, shape: b_shape} = Nx.to_tensor(b) @@ -2235,17 +2241,9 @@ defmodule Nx.LinAlg do ) end - case a_shape do - {m, n} when m == n -> - Nx.LinAlg.solve(a, b) - - {m, n} when m != n -> - Nx.LinAlg.pinv(a, eps: 1.0e-15) - |> Nx.dot(b) - - _ -> - nil - end + a + |> Nx.LinAlg.pinv(eps: opts[:eps]) + |> Nx.dot(b) end defp apply_vectorized(tensor, fun) when is_function(fun, 1) do