Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: least_squares implementation #1550

Merged
merged 2 commits into from
Oct 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion exla/test/exla/nx_linalg_doctest_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ defmodule EXLA.MLIR.NxLinAlgDoctestTest do
invert: 1,
matrix_power: 2
]
@rounding_error_doctests [triangular_solve: 3, eigh: 2, cholesky: 1, least_squares: 2]
@rounding_error_doctests [triangular_solve: 3, eigh: 2, cholesky: 1, least_squares: 3]

@excluded_doctests @function_clause_error_doctests ++
@rounding_error_doctests ++
Expand Down
24 changes: 11 additions & 13 deletions nx/lib/nx/lin_alg.ex
Original file line number Diff line number Diff line change
Expand Up @@ -2152,12 +2152,16 @@ defmodule Nx.LinAlg do
@doc """
Return the least-squares solution to a linear matrix equation Ax = b.

## Options

* `:eps` - Rounding error threshold used to assume values as 0. Defaults to `1.0e-15`

## Examples

iex> Nx.LinAlg.least_squares(Nx.tensor([[1, 2], [2, 3]]), Nx.tensor([1, 2]))
#Nx.Tensor<
f32[2]
[1.0000004768371582, -2.665601925855299e-7]
[0.9977624416351318, 0.0011188983917236328]
>

iex> Nx.LinAlg.least_squares(Nx.tensor([[0, 1], [1, 1], [2, 1], [3, 1]]), Nx.tensor([-1, 0.2, 0.9, 2.1]))
Expand Down Expand Up @@ -2187,7 +2191,9 @@ defmodule Nx.LinAlg do
** (ArgumentError) the number of rows of the matrix as the 1st argument and the number of columns of the vector as the 2nd argument must be the same, got 1st argument shape {2, 2} and 2nd argument shape {3}
"""
@doc from_backend: false
defn least_squares(a, b) do
defn least_squares(a, b, opts \\ []) do
opts = keyword!(opts, eps: 1.0e-15)

%T{type: a_type, shape: a_shape} = Nx.to_tensor(a)
a_size = Nx.rank(a_shape)
%T{type: b_type, shape: b_shape} = Nx.to_tensor(b)
Expand Down Expand Up @@ -2235,17 +2241,9 @@ defmodule Nx.LinAlg do
)
end

case a_shape do
{m, n} when m == n ->
Nx.LinAlg.solve(a, b)

{m, n} when m != n ->
Nx.LinAlg.pinv(a, eps: 1.0e-15)
|> Nx.dot(b)

_ ->
nil
end
a
|> Nx.LinAlg.pinv(eps: opts[:eps])
|> Nx.dot(b)
end

defp apply_vectorized(tensor, fun) when is_function(fun, 1) do
Expand Down
Loading