diff --git a/exla/test/exla/nx_linalg_doctest_test.exs b/exla/test/exla/nx_linalg_doctest_test.exs index 6df3aeec10..10c2cbce05 100644 --- a/exla/test/exla/nx_linalg_doctest_test.exs +++ b/exla/test/exla/nx_linalg_doctest_test.exs @@ -10,7 +10,7 @@ defmodule EXLA.MLIR.NxLinAlgDoctestTest do invert: 1, matrix_power: 2 ] - @rounding_error_doctests [triangular_solve: 3, eigh: 2, cholesky: 1, least_squares: 2] + @rounding_error_doctests [triangular_solve: 3, eigh: 2, cholesky: 1, least_squares: 3] @excluded_doctests @function_clause_error_doctests ++ @rounding_error_doctests ++ diff --git a/nx/lib/nx/lin_alg.ex b/nx/lib/nx/lin_alg.ex index 0a353f4320..450a48cae4 100644 --- a/nx/lib/nx/lin_alg.ex +++ b/nx/lib/nx/lin_alg.ex @@ -2152,12 +2152,16 @@ defmodule Nx.LinAlg do @doc """ Return the least-squares solution to a linear matrix equation Ax = b. + ## Options + + * `:eps` - Rounding error threshold used to assume values as 0. Defaults to `1.0e-15` + ## Examples iex> Nx.LinAlg.least_squares(Nx.tensor([[1, 2], [2, 3]]), Nx.tensor([1, 2])) #Nx.Tensor< f32[2] - [1.0000004768371582, -2.665601925855299e-7] + [0.9977624416351318, 0.0011188983917236328] > iex> Nx.LinAlg.least_squares(Nx.tensor([[0, 1], [1, 1], [2, 1], [3, 1]]), Nx.tensor([-1, 0.2, 0.9, 2.1])) @@ -2187,7 +2191,9 @@ defmodule Nx.LinAlg do ** (ArgumentError) the number of rows of the matrix as the 1st argument and the number of columns of the vector as the 2nd argument must be the same, got 1st argument shape {2, 2} and 2nd argument shape {3} """ @doc from_backend: false - defn least_squares(a, b) do + defn least_squares(a, b, opts \\ []) do + opts = keyword!(opts, eps: 1.0e-15) + %T{type: a_type, shape: a_shape} = Nx.to_tensor(a) a_size = Nx.rank(a_shape) %T{type: b_type, shape: b_shape} = Nx.to_tensor(b) @@ -2235,17 +2241,9 @@ defmodule Nx.LinAlg do ) end - case a_shape do - {m, n} when m == n -> - Nx.LinAlg.solve(a, b) - - {m, n} when m != n -> - Nx.LinAlg.pinv(a, eps: 1.0e-15) - |> Nx.dot(b) - - _ -> - nil - end + a + |> Nx.LinAlg.pinv(eps: opts[:eps]) + |> Nx.dot(b) end defp apply_vectorized(tensor, fun) when is_function(fun, 1) do