diff --git a/docs/src/python/linalg.rst b/docs/src/python/linalg.rst index f6c51ed0b..853ad393b 100644 --- a/docs/src/python/linalg.rst +++ b/docs/src/python/linalg.rst @@ -5,8 +5,8 @@ Linear Algebra .. currentmodule:: mlx.core.linalg -.. autosummary:: - :toctree: _autosummary +.. autosummary:: + :toctree: _autosummary inv tri_inv @@ -18,3 +18,4 @@ Linear Algebra svd eigvalsh eigh + solve diff --git a/mlx/backend/accelerate/primitives.cpp b/mlx/backend/accelerate/primitives.cpp index 1f80224ad..69f7eadca 100644 --- a/mlx/backend/accelerate/primitives.cpp +++ b/mlx/backend/accelerate/primitives.cpp @@ -82,6 +82,7 @@ DEFAULT(Transpose) DEFAULT(Inverse) DEFAULT(Cholesky) DEFAULT_MULTI(Eigh) +DEFAULT_MULTI(Solve) void Abs::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); diff --git a/mlx/backend/common/CMakeLists.txt b/mlx/backend/common/CMakeLists.txt index 4fca2274e..7a77f82e6 100644 --- a/mlx/backend/common/CMakeLists.txt +++ b/mlx/backend/common/CMakeLists.txt @@ -52,6 +52,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/svd.cpp ${CMAKE_CURRENT_SOURCE_DIR}/inverse.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cholesky.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/solve.cpp ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp) diff --git a/mlx/backend/common/default_primitives.cpp b/mlx/backend/common/default_primitives.cpp index 547d8e25d..26c745537 100644 --- a/mlx/backend/common/default_primitives.cpp +++ b/mlx/backend/common/default_primitives.cpp @@ -111,6 +111,7 @@ DEFAULT(Transpose) DEFAULT(Inverse) DEFAULT(Cholesky) DEFAULT_MULTI(Eigh) +DEFAULT_MULTI(Solve) namespace { diff --git a/mlx/backend/common/solve.cpp b/mlx/backend/common/solve.cpp new file mode 100644 index 000000000..be0f190d8 --- /dev/null +++ b/mlx/backend/common/solve.cpp @@ -0,0 +1,131 @@ +// Copyright © 2024 Apple Inc. + +#include "mlx/allocator.h" +#include "mlx/backend/common/copy.h" +#include "mlx/backend/common/lapack_helper.h" +#include "mlx/primitives.h" + +#ifdef ACCELERATE_NEW_LAPACK +#include +#else +#include +#endif + +#include + +namespace mlx::core { + +namespace { + +// Wrapper to account for differences in +// LAPACK implementations (basically how to pass the 'trans' string to fortran). +int sgetrs_wrapper(char trans, int N, int NRHS, int* ipiv, float* a, float* b) { + int info; + +#ifdef LAPACK_FORTRAN_STRLEN_END + sgetrs_( + /* trans */ &trans, + /* n */ &N, + /* nrhs */ &NRHS, + /* a */ a, + /* lda */ &N, + /* ipiv */ ipiv, + /* b */ b, + /* ldb */ &N, + /* info */ &info, + /* trans_len = */ static_cast(1)); +#else + sgetrs_( + /* trans */ &trans, + /* n */ &N, + /* nrhs */ &NRHS, + /* a */ a, + /* lda */ &N, + /* ipiv */ ipiv, + /* b */ b, + /* ldb */ &N, + /* info */ &info); +#endif + + return info; +} + +} // namespace + +void solve_impl(const array& a, const array& b, array& out) { + int N = a.shape(-2); + int NRHS = out.shape(-1); + std::vector ipiv(N); + + // copy b into out and make it col-contiguous + auto flags = out.flags(); + flags.col_contiguous = true; + flags.row_contiguous = false; + std::vector strides(a.ndim(), 0); + std::copy(out.strides().begin(), out.strides().end(), strides.begin()); + strides[a.ndim() - 2] = 1; + strides[a.ndim() - 1] = N; + + out.set_data( + allocator::malloc_or_wait(out.nbytes()), out.nbytes(), strides, flags); + copy_inplace(b, out, CopyType::GeneralGeneral); + + // lapack clobbers the input, so we have to make a copy. the copy doesn't need + // to be col-contiguous because sgetrs has a transpose parameter (trans='T'). + array a_cpy(a.shape(), float32, nullptr, {}); + copy( + a, + a_cpy, + a.flags().row_contiguous ? CopyType::Vector : CopyType::General); + + float* a_ptr = a_cpy.data(); + float* out_ptr = out.data(); + int* ipiv_ptr = ipiv.data(); + + int info; + size_t num_matrices = a.size() / (N * N); + for (size_t i = 0; i < num_matrices; i++) { + // Compute LU factorization of A + MLX_LAPACK_FUNC(sgetrf) + (/* m */ &N, + /* n */ &N, + /* a */ a_ptr, + /* lda */ &N, + /* ipiv */ ipiv_ptr, + /* info */ &info); + + if (info != 0) { + std::stringstream ss; + ss << "solve_impl: sgetrf_ failed with code " << info + << ((info > 0) ? " because matrix is singular" + : " becuase argument had an illegal value"); + throw std::runtime_error(ss.str()); + } + + static constexpr char trans = 'T'; + // Solve the system using the LU factors from sgetrf + info = sgetrs_wrapper(trans, N, NRHS, ipiv_ptr, a_ptr, out_ptr); + + if (info != 0) { + std::stringstream ss; + ss << "solve_impl: sgetrs_ failed with code " << info; + throw std::runtime_error(ss.str()); + } + + // Advance pointers to the next matrix + a_ptr += N * N; + out_ptr += N * NRHS; + } +} + +void Solve::eval( + const std::vector& inputs, + std::vector& outputs) { + assert(inputs.size() == 2); + if (inputs[0].dtype() != float32 || inputs[1].dtype() != float32) { + throw std::runtime_error("[Solve::eval] only supports float32."); + } + solve_impl(inputs[0], inputs[1], outputs[0]); +} + +} // namespace mlx::core diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp index e5a7d885b..35586a257 100644 --- a/mlx/backend/metal/primitives.cpp +++ b/mlx/backend/metal/primitives.cpp @@ -438,4 +438,10 @@ void View::eval_gpu(const std::vector& inputs, array& out) { } } +void Solve::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + throw std::runtime_error("[Solve::eval_gpu] Metal Solve NYI."); +} + } // namespace mlx::core diff --git a/mlx/backend/no_cpu/primitives.cpp b/mlx/backend/no_cpu/primitives.cpp index c87fcc8bb..db9f40013 100644 --- a/mlx/backend/no_cpu/primitives.cpp +++ b/mlx/backend/no_cpu/primitives.cpp @@ -111,5 +111,6 @@ NO_CPU(Tanh) NO_CPU(Transpose) NO_CPU(Inverse) NO_CPU(View) +NO_CPU_MULTI(Solve) } // namespace mlx::core diff --git a/mlx/backend/no_metal/primitives.cpp b/mlx/backend/no_metal/primitives.cpp index aaee51d83..f8556a76d 100644 --- a/mlx/backend/no_metal/primitives.cpp +++ b/mlx/backend/no_metal/primitives.cpp @@ -114,6 +114,7 @@ NO_GPU(Inverse) NO_GPU(Cholesky) NO_GPU_MULTI(Eigh) NO_GPU(View) +NO_GPU_MULTI(Solve) namespace fast { NO_GPU_MULTI(LayerNorm) diff --git a/mlx/linalg.cpp b/mlx/linalg.cpp index daf5573fc..74c052366 100644 --- a/mlx/linalg.cpp +++ b/mlx/linalg.cpp @@ -500,4 +500,43 @@ std::pair eigh( return std::make_pair(out[0], out[1]); } +array solve(const array& a, const array& b, StreamOrDevice s /* = {} */) { + if (a.dtype() != float32 && b.dtype() != float32) { + std::ostringstream msg; + msg << "[linalg::solve] Input array must have type float32. Received arrays " + << "with type " << a.dtype() << " and " << b.dtype() << "."; + } + + if (a.ndim() < 2) { + std::ostringstream msg; + msg << "[linalg::solve] First input must have >= 2 dimensions. " + << "Received array with " << a.ndim() << " dimensions."; + throw std::invalid_argument(msg.str()); + } + + if (b.ndim() < 1) { + std::ostringstream msg; + msg << "[linalg::solve] Second input must have >= 1 dimensions. " + << "Received array with " << b.ndim() << " dimensions."; + throw std::invalid_argument(msg.str()); + } + + if (a.shape(-1) != a.shape(-2)) { + std::ostringstream msg; + msg << "[linalg::solve] First input must be a square matrix. " + << "Received array with shape " << a.shape() << "."; + throw std::invalid_argument(msg.str()); + } + + if (a.shape(-1) != b.shape(b.ndim() - 2)) { + std::ostringstream msg; + msg << "[linalg::solve] Last dimension of first input with shape " + << a.shape() << " must match second to last dimension of" + << " second input with shape " << b.shape() << "."; + throw std::invalid_argument(msg.str()); + } + return array( + b.shape(), out_type, std::make_shared(to_stream(s)), {a, b}); +} + } // namespace mlx::core::linalg diff --git a/mlx/linalg.h b/mlx/linalg.h index 4ea81bef0..bca1d56af 100644 --- a/mlx/linalg.h +++ b/mlx/linalg.h @@ -74,6 +74,8 @@ array pinv(const array& a, StreamOrDevice s = {}); array cholesky_inv(const array& a, bool upper = false, StreamOrDevice s = {}); +array solve(const array& a, const array& b, StreamOrDevice s = {}); + /** * Compute the cross product of two arrays along the given axis. */ diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index c9f839d4b..9ea5cc513 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -4201,6 +4201,17 @@ std::pair, std::vector> SVD::vmap( return {{linalg::svd(a, stream())}, {ax, ax, ax}}; } +std::pair, std::vector> Solve::vmap( + const std::vector& inputs, + const std::vector& axes) { + auto maybe_move_ax = [this](auto& arr, auto ax) { + return ax > 0 ? moveaxis(arr, ax, 0, stream()) : arr; + }; + auto a = maybe_move_ax(inputs[0], axes[0]); + auto b = maybe_move_ax(inputs[1], axes[1]); + return {{linalg::solve(a, b, stream())}, {0}}; +} + std::pair, std::vector> Inverse::vmap( const std::vector& inputs, const std::vector& axes) { diff --git a/mlx/primitives.h b/mlx/primitives.h index f2b5bab7c..f15d66f40 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -2202,7 +2202,6 @@ class Eigh : public Primitive { : Primitive(stream), uplo_(std::move(uplo)), compute_eigenvectors_(compute_eigenvectors) {} - void eval_cpu(const std::vector& inputs, std::vector& outputs) override; void eval_gpu(const std::vector& inputs, std::vector& outputs) @@ -2236,4 +2235,20 @@ class Eigh : public Primitive { bool compute_eigenvectors_; }; +class Solve : public Primitive { + public: + explicit Solve(Stream stream) : Primitive(stream) {} + void eval_cpu(const std::vector& inputs, std::vector& outputs) + override; + void eval_gpu(const std::vector& inputs, std::vector& outputs) + override; + + DEFINE_VMAP() + DEFINE_PRINT(Solve) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + void eval(const std::vector& inputs, std::vector& outputs); +}; + } // namespace mlx::core diff --git a/python/src/linalg.cpp b/python/src/linalg.cpp index e2c3aea23..0510f0f69 100644 --- a/python/src/linalg.cpp +++ b/python/src/linalg.cpp @@ -443,7 +443,6 @@ void init_linalg(nb::module_& parent_module) { m.def( "eigh", [](const array& a, const std::string UPLO, StreamOrDevice s) { - // TODO avoid cast? auto result = eigh(a, UPLO, s); return nb::make_tuple(result.first, result.second); }, @@ -486,4 +485,23 @@ void init_linalg(nb::module_& parent_module) { array([[ 0.707107, -0.707107], [ 0.707107, 0.707107]], dtype=float32) )pbdoc"); + m.def( + "solve", + &solve, + "a"_a, + "b"_a, + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def solve(a: array, b: array, *, stream: Union[None, Stream, Device] = None) -> array"), + R"pbdoc( + Compute the solution to a system of linear equations ax = b. + + Args: + a (array): Input array. + b (array): Input array. + + Returns: + array: The unique solution to the system ax = b. + )pbdoc"); } diff --git a/python/tests/test_linalg.py b/python/tests/test_linalg.py index 695d7704f..f81186818 100644 --- a/python/tests/test_linalg.py +++ b/python/tests/test_linalg.py @@ -319,6 +319,60 @@ def check_eigs_and_vecs(A_np, kwargs={}): mx.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) ) # Non-square matrix + def test_solve(self): + mx.random.seed(7) + + # Test 3x3 matrix with 1D rhs + a = mx.array([[3.0, 1.0, 2.0], [1.0, 8.0, 6.0], [9.0, 2.0, 5.0]]) + b = mx.array([11.0, 35.0, 28.0]) + + result = mx.linalg.solve(a, b, stream=mx.cpu) + expected = np.linalg.solve(a, b) + self.assertTrue(np.allclose(result, expected)) + + # Test symmetric positive-definite matrix + N = 5 + a = mx.random.uniform(shape=(N, N)) + a = mx.matmul(a, a.T) + N * mx.eye(N) + b = mx.random.uniform(shape=(N, 1)) + + result = mx.linalg.solve(a, b, stream=mx.cpu) + expected = np.linalg.solve(a, b) + self.assertTrue(np.allclose(result, expected, atol=1e-5)) + + # Test batch dimension + a = mx.random.uniform(shape=(5, 5, 4, 4)) + b = mx.random.uniform(shape=(5, 5, 4, 1)) + + result = mx.linalg.solve(a, b, stream=mx.cpu) + expected = np.linalg.solve(a, b) + self.assertTrue(np.allclose(result, expected, atol=1e-5)) + + # Test large matrix + N = 1000 + a = mx.random.uniform(shape=(N, N)) + b = mx.random.uniform(shape=(N, 1)) + + result = mx.linalg.solve(a, b, stream=mx.cpu) + expected = np.linalg.solve(a, b) + self.assertTrue(np.allclose(result, expected, atol=1e-2)) + + # Test multi-column rhs + a = mx.random.uniform(shape=(5, 5)) + b = mx.random.uniform(shape=(5, 8)) + + result = mx.linalg.solve(a, b, stream=mx.cpu) + expected = np.linalg.solve(a, b) + self.assertTrue(np.allclose(result, expected, atol=1e-5)) + + # Test batched multi-column rhs + a = mx.concat([a, a, a, a, a, a]).reshape((3, 2, 5, 5)) + b = mx.concat([b, b, b, b, b, b]).reshape((3, 2, 5, 8)) + + result = mx.linalg.solve(a, b, stream=mx.cpu) + expected = np.linalg.solve(a, b) + self.assertTrue(np.allclose(result, expected, atol=1e-5)) + if __name__ == "__main__": unittest.main() diff --git a/tests/linalg_tests.cpp b/tests/linalg_tests.cpp index f0b34cc01..a8b03ee14 100644 --- a/tests/linalg_tests.cpp +++ b/tests/linalg_tests.cpp @@ -473,3 +473,53 @@ TEST_CASE("test matrix eigh") { // Verify eigendecomposition CHECK(allclose(matmul(A, eigvecs), eigvals * eigvecs).item()); } + +TEST_CASE("test solve") { + // 0D and 1D throw + CHECK_THROWS(linalg::solve(array(0.), array(0.), Device::cpu)); + CHECK_THROWS(linalg::solve(array({0.}), array({0.}), Device::cpu)); + + // Unsupported types throw + CHECK_THROWS( + linalg::solve(array({0, 1, 1, 2}, {2, 2}), array({1, 3}), Device::cpu)); + + // Non-square throws + array a = reshape(arange(6), {3, 2}); + array b = reshape(arange(3), {3, 1}); + CHECK_THROWS(linalg::solve(a, b, Device::cpu)); + + // Test 2x2 matrix with 1D rhs + a = array({2., 1., 1., 3.}, {2, 2}); + b = array({8., 13.}, {2}); + + array result = linalg::solve(a, b, Device::cpu); + CHECK(allclose(matmul(a, result), b).item()); + + // Test 3x3 matrix + a = array({1., 2., 3., 4., 5., 6., 7., 8., 10.}, {3, 3}); + b = array({6., 15., 25.}, {3, 1}); + + result = linalg::solve(a, b, Device::cpu); + CHECK(allclose(matmul(a, result), b).item()); + + // Test batch dimension + a = reshape(concatenate({a, a, a, a, a}), {5, 3, 3}); + b = reshape(concatenate({b, b, b, b, b}), {5, 3, 1}); + + result = linalg::solve(a, b, Device::cpu); + CHECK(allclose(matmul(a, result), b).item()); + + // Test multi-column rhs + a = array({2., 1., 1., 1., 3., 2., 1., 0., 0.}, {3, 3}); + b = array({4., 2., 5., 3., 6., 1.}, {3, 2}); + + result = linalg::solve(a, b, Device::cpu); + CHECK(allclose(matmul(a, result), b).item()); + + // Test batch multi-column rhs + a = reshape(concatenate({a, a, a, a, a}), {5, 3, 3}); + b = reshape(concatenate({b, b, b, b, b}), {5, 3, 2}); + + result = linalg::solve(a, b, Device::cpu); + CHECK(allclose(matmul(a, result), b).item()); +}