linalg solve backend
abeleinin authored and awni committed Oct 25, 2024
1 parent 8e88e30 commit f78f6d4
Showing 15 changed files with 336 additions and 4 deletions.
5 changes: 3 additions & 2 deletions docs/src/python/linalg.rst
@@ -5,8 +5,8 @@ Linear Algebra

.. currentmodule:: mlx.core.linalg

.. autosummary::
   :toctree: _autosummary

inv
tri_inv
@@ -18,3 +18,4 @@ Linear Algebra
svd
eigvalsh
eigh
solve
1 change: 1 addition & 0 deletions mlx/backend/accelerate/primitives.cpp
@@ -82,6 +82,7 @@ DEFAULT(Transpose)
DEFAULT(Inverse)
DEFAULT(Cholesky)
DEFAULT_MULTI(Eigh)
DEFAULT_MULTI(Solve)

void Abs::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
1 change: 1 addition & 0 deletions mlx/backend/common/CMakeLists.txt
@@ -52,6 +52,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/svd.cpp
${CMAKE_CURRENT_SOURCE_DIR}/inverse.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cholesky.cpp
${CMAKE_CURRENT_SOURCE_DIR}/solve.cpp
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp)

1 change: 1 addition & 0 deletions mlx/backend/common/default_primitives.cpp
@@ -111,6 +111,7 @@ DEFAULT(Transpose)
DEFAULT(Inverse)
DEFAULT(Cholesky)
DEFAULT_MULTI(Eigh)
DEFAULT_MULTI(Solve)

namespace {

131 changes: 131 additions & 0 deletions mlx/backend/common/solve.cpp
@@ -0,0 +1,131 @@
// Copyright © 2024 Apple Inc.

#include "mlx/allocator.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/lapack_helper.h"
#include "mlx/primitives.h"

#ifdef ACCELERATE_NEW_LAPACK
#include <Accelerate/Accelerate.h>
#else
#include <lapack.h>
#endif

#include <cassert>

namespace mlx::core {

namespace {

// Wrapper to account for differences between LAPACK implementations
// (basically, how the 'trans' string is passed to Fortran).
int sgetrs_wrapper(char trans, int N, int NRHS, int* ipiv, float* a, float* b) {
int info;

#ifdef LAPACK_FORTRAN_STRLEN_END
sgetrs_(
/* trans */ &trans,
/* n */ &N,
/* nrhs */ &NRHS,
/* a */ a,
/* lda */ &N,
/* ipiv */ ipiv,
/* b */ b,
/* ldb */ &N,
/* info */ &info,
/* trans_len = */ static_cast<size_t>(1));
#else
sgetrs_(
/* trans */ &trans,
/* n */ &N,
/* nrhs */ &NRHS,
/* a */ a,
/* lda */ &N,
/* ipiv */ ipiv,
/* b */ b,
/* ldb */ &N,
/* info */ &info);
#endif

return info;
}

} // namespace

void solve_impl(const array& a, const array& b, array& out) {
int N = a.shape(-2);
int NRHS = out.shape(-1);
std::vector<int> ipiv(N);

// copy b into out and make it col-contiguous
auto flags = out.flags();
flags.col_contiguous = true;
flags.row_contiguous = false;
std::vector<size_t> strides(a.ndim(), 0);
std::copy(out.strides().begin(), out.strides().end(), strides.begin());
strides[a.ndim() - 2] = 1;
strides[a.ndim() - 1] = N;

out.set_data(
allocator::malloc_or_wait(out.nbytes()), out.nbytes(), strides, flags);
copy_inplace(b, out, CopyType::GeneralGeneral);

// lapack clobbers the input, so we have to make a copy. the copy doesn't need
// to be col-contiguous because sgetrs has a transpose parameter (trans='T').
array a_cpy(a.shape(), float32, nullptr, {});
copy(
a,
a_cpy,
a.flags().row_contiguous ? CopyType::Vector : CopyType::General);

float* a_ptr = a_cpy.data<float>();
float* out_ptr = out.data<float>();
int* ipiv_ptr = ipiv.data();

int info;
size_t num_matrices = a.size() / (N * N);
for (size_t i = 0; i < num_matrices; i++) {
// Compute LU factorization of A
MLX_LAPACK_FUNC(sgetrf)
(/* m */ &N,
/* n */ &N,
/* a */ a_ptr,
/* lda */ &N,
/* ipiv */ ipiv_ptr,
/* info */ &info);

if (info != 0) {
std::stringstream ss;
ss << "solve_impl: sgetrf_ failed with code " << info
         << ((info > 0) ? " because the matrix is singular"
                        : " because an argument had an illegal value");
throw std::runtime_error(ss.str());
}

static constexpr char trans = 'T';
// Solve the system using the LU factors from sgetrf
info = sgetrs_wrapper(trans, N, NRHS, ipiv_ptr, a_ptr, out_ptr);

if (info != 0) {
std::stringstream ss;
ss << "solve_impl: sgetrs_ failed with code " << info;
throw std::runtime_error(ss.str());
}

// Advance pointers to the next matrix
a_ptr += N * N;
out_ptr += N * NRHS;
}
}

void Solve::eval(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() == 2);
if (inputs[0].dtype() != float32 || inputs[1].dtype() != float32) {
throw std::runtime_error("[Solve::eval] only supports float32.");
}
solve_impl(inputs[0], inputs[1], outputs[0]);
}

} // namespace mlx::core
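The layout trick in `solve_impl` is worth spelling out: LAPACK assumes column-major storage, so the row-major copy of `a`, read column-major, is effectively Aᵀ. `sgetrf` therefore factors Aᵀ, and passing trans='T' to `sgetrs` solves (Aᵀ)ᵀX = AX = B. A minimal NumPy/SciPy sketch of the same identity, using `lu_factor`/`lu_solve` as stand-ins for `sgetrf`/`sgetrs`:

```python
import numpy as np
from scipy.linalg import lu_factor, lu_solve

rng = np.random.default_rng(0)
a = rng.standard_normal((4, 4))
b = rng.standard_normal((4, 2))

# A row-major buffer of `a`, read column-major, holds a.T.
# Factor a.T, then solve with trans=1: (a.T).T x = a x = b,
# mirroring sgetrf on the row-major copy plus sgetrs with trans='T'.
lu, piv = lu_factor(a.T)
x = lu_solve((lu, piv), b, trans=1)

assert np.allclose(x, np.linalg.solve(a, b))
```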
6 changes: 6 additions & 0 deletions mlx/backend/metal/primitives.cpp
@@ -438,4 +438,10 @@ void View::eval_gpu(const std::vector<array>& inputs, array& out) {
}
}

void Solve::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
throw std::runtime_error("[Solve::eval_gpu] Metal Solve NYI.");
}

} // namespace mlx::core
1 change: 1 addition & 0 deletions mlx/backend/no_cpu/primitives.cpp
@@ -111,5 +111,6 @@ NO_CPU(Tanh)
NO_CPU(Transpose)
NO_CPU(Inverse)
NO_CPU(View)
NO_CPU_MULTI(Solve)

} // namespace mlx::core
1 change: 1 addition & 0 deletions mlx/backend/no_metal/primitives.cpp
@@ -114,6 +114,7 @@ NO_GPU(Inverse)
NO_GPU(Cholesky)
NO_GPU_MULTI(Eigh)
NO_GPU(View)
NO_GPU_MULTI(Solve)

namespace fast {
NO_GPU_MULTI(LayerNorm)
39 changes: 39 additions & 0 deletions mlx/linalg.cpp
@@ -500,4 +500,43 @@ std::pair<array, array> eigh(
return std::make_pair(out[0], out[1]);
}

array solve(const array& a, const array& b, StreamOrDevice s /* = {} */) {
  if (a.dtype() != float32 || b.dtype() != float32) {
    std::ostringstream msg;
    msg << "[linalg::solve] Input arrays must have type float32. Received arrays "
        << "with type " << a.dtype() << " and " << b.dtype() << ".";
    throw std::invalid_argument(msg.str());
  }

if (a.ndim() < 2) {
std::ostringstream msg;
msg << "[linalg::solve] First input must have >= 2 dimensions. "
<< "Received array with " << a.ndim() << " dimensions.";
throw std::invalid_argument(msg.str());
}

if (b.ndim() < 1) {
std::ostringstream msg;
msg << "[linalg::solve] Second input must have >= 1 dimensions. "
<< "Received array with " << b.ndim() << " dimensions.";
throw std::invalid_argument(msg.str());
}

if (a.shape(-1) != a.shape(-2)) {
std::ostringstream msg;
msg << "[linalg::solve] First input must be a square matrix. "
<< "Received array with shape " << a.shape() << ".";
throw std::invalid_argument(msg.str());
}

if (a.shape(-1) != b.shape(b.ndim() - 2)) {
std::ostringstream msg;
msg << "[linalg::solve] Last dimension of first input with shape "
<< a.shape() << " must match second to last dimension of"
<< " second input with shape " << b.shape() << ".";
throw std::invalid_argument(msg.str());
}
  return array(
      b.shape(), float32, std::make_shared<Solve>(to_stream(s)), {a, b});
}

} // namespace mlx::core::linalg
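The checks above fix the calling convention: `a` is `(..., n, n)` and square, `b` shares the size-`n` row dimension, and the result takes `b`'s shape. A minimal sketch of the expected shapes (batch dimensions are assumed to match elementwise, since the CPU kernel walks the matrices in lockstep; broadcasting is not implied here):

```python
import mlx.core as mx

a = mx.random.uniform(shape=(2, 3, 4, 4))  # (..., n, n), square
b = mx.random.uniform(shape=(2, 3, 4, 5))  # (..., n, k)

# Only the CPU path exists in this commit, so target the CPU stream.
x = mx.linalg.solve(a, b, stream=mx.cpu)
print(x.shape)  # (2, 3, 4, 5), i.e. b.shape
```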
2 changes: 2 additions & 0 deletions mlx/linalg.h
@@ -74,6 +74,8 @@ array pinv(const array& a, StreamOrDevice s = {});

array cholesky_inv(const array& a, bool upper = false, StreamOrDevice s = {});

array solve(const array& a, const array& b, StreamOrDevice s = {});

/**
* Compute the cross product of two arrays along the given axis.
*/
11 changes: 11 additions & 0 deletions mlx/primitives.cpp
@@ -4201,6 +4201,17 @@ std::pair<std::vector<array>, std::vector<int>> SVD::vmap(
return {{linalg::svd(a, stream())}, {ax, ax, ax}};
}

std::pair<std::vector<array>, std::vector<int>> Solve::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
auto maybe_move_ax = [this](auto& arr, auto ax) {
return ax > 0 ? moveaxis(arr, ax, 0, stream()) : arr;
};
auto a = maybe_move_ax(inputs[0], axes[0]);
auto b = maybe_move_ax(inputs[1], axes[1]);
return {{linalg::solve(a, b, stream())}, {0}};
}

std::pair<std::vector<array>, std::vector<int>> Inverse::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
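`Solve::vmap` only moves a batched axis to the front and re-dispatches to `linalg::solve`, which already loops over leading dimensions. A small sketch of what this enables from Python (hypothetical usage, assuming both inputs are batched along axis 0):

```python
import mlx.core as mx

mx.set_default_device(mx.cpu)  # Metal kernel is NYI in this commit

a = mx.random.uniform(shape=(8, 4, 4))
b = mx.random.uniform(shape=(8, 4, 1))

# vmap over the leading axis of both inputs routes through Solve::vmap.
batched_solve = mx.vmap(mx.linalg.solve, in_axes=(0, 0))
x = batched_solve(a, b)
mx.eval(x)
print(x.shape)  # (8, 4, 1)
```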
17 changes: 16 additions & 1 deletion mlx/primitives.h
@@ -2202,7 +2202,6 @@ class Eigh : public Primitive {
: Primitive(stream),
uplo_(std::move(uplo)),
compute_eigenvectors_(compute_eigenvectors) {}

void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
@@ -2236,4 +2235,20 @@
bool compute_eigenvectors_;
};

class Solve : public Primitive {
public:
explicit Solve(Stream stream) : Primitive(stream) {}
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;

DEFINE_VMAP()
DEFINE_PRINT(Solve)
DEFINE_DEFAULT_IS_EQUIVALENT()

private:
void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
};

} // namespace mlx::core
20 changes: 19 additions & 1 deletion python/src/linalg.cpp
@@ -443,7 +443,6 @@ void init_linalg(nb::module_& parent_module) {
m.def(
"eigh",
[](const array& a, const std::string UPLO, StreamOrDevice s) {
// TODO avoid cast?
auto result = eigh(a, UPLO, s);
return nb::make_tuple(result.first, result.second);
},
@@ -486,4 +485,23 @@
array([[ 0.707107, -0.707107],
[ 0.707107, 0.707107]], dtype=float32)
)pbdoc");
m.def(
"solve",
&solve,
"a"_a,
"b"_a,
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def solve(a: array, b: array, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
            Compute the solution to a system of linear equations ``ax = b``.

            Args:
                a (array): The square coefficient matrix.
                b (array): The right-hand side.

            Returns:
                array: The unique solution to the system ``ax = b``.
)pbdoc");
}
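A quick sanity check of the new binding, with values chosen so the answer is easy to verify by hand (the CPU stream is required while the Metal kernel is unimplemented):

```python
import mlx.core as mx

a = mx.array([[3.0, 1.0], [1.0, 2.0]])
b = mx.array([[9.0], [8.0]])

# Solves 3x + y = 9 and x + 2y = 8, so x = 2, y = 3.
x = mx.linalg.solve(a, b, stream=mx.cpu)
print(x)  # [[2], [3]] as float32
```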
54 changes: 54 additions & 0 deletions python/tests/test_linalg.py
@@ -319,6 +319,60 @@ def check_eigs_and_vecs(A_np, kwargs={}):
mx.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
) # Non-square matrix

def test_solve(self):
mx.random.seed(7)

# Test 3x3 matrix with 1D rhs
a = mx.array([[3.0, 1.0, 2.0], [1.0, 8.0, 6.0], [9.0, 2.0, 5.0]])
b = mx.array([11.0, 35.0, 28.0])

result = mx.linalg.solve(a, b, stream=mx.cpu)
expected = np.linalg.solve(a, b)
self.assertTrue(np.allclose(result, expected))

# Test symmetric positive-definite matrix
N = 5
a = mx.random.uniform(shape=(N, N))
a = mx.matmul(a, a.T) + N * mx.eye(N)
b = mx.random.uniform(shape=(N, 1))

result = mx.linalg.solve(a, b, stream=mx.cpu)
expected = np.linalg.solve(a, b)
self.assertTrue(np.allclose(result, expected, atol=1e-5))

# Test batch dimension
a = mx.random.uniform(shape=(5, 5, 4, 4))
b = mx.random.uniform(shape=(5, 5, 4, 1))

result = mx.linalg.solve(a, b, stream=mx.cpu)
expected = np.linalg.solve(a, b)
self.assertTrue(np.allclose(result, expected, atol=1e-5))

# Test large matrix
N = 1000
a = mx.random.uniform(shape=(N, N))
b = mx.random.uniform(shape=(N, 1))

result = mx.linalg.solve(a, b, stream=mx.cpu)
expected = np.linalg.solve(a, b)
self.assertTrue(np.allclose(result, expected, atol=1e-2))

# Test multi-column rhs
a = mx.random.uniform(shape=(5, 5))
b = mx.random.uniform(shape=(5, 8))

result = mx.linalg.solve(a, b, stream=mx.cpu)
expected = np.linalg.solve(a, b)
self.assertTrue(np.allclose(result, expected, atol=1e-5))

# Test batched multi-column rhs
a = mx.concat([a, a, a, a, a, a]).reshape((3, 2, 5, 5))
b = mx.concat([b, b, b, b, b, b]).reshape((3, 2, 5, 8))

result = mx.linalg.solve(a, b, stream=mx.cpu)
expected = np.linalg.solve(a, b)
self.assertTrue(np.allclose(result, expected, atol=1e-5))


if __name__ == "__main__":
unittest.main()