From 45dbbca35aa51c82efbb74cf90298c283535d828 Mon Sep 17 00:00:00 2001
From: aleinin <95333017+abeleinin@users.noreply.github.com>
Date: Wed, 2 Oct 2024 23:32:49 -0500
Subject: [PATCH] linalg solve backend

---
 docs/src/python/linalg.rst                |   1 +
 mlx/backend/accelerate/primitives.cpp     |   1 +
 mlx/backend/common/CMakeLists.txt         |   1 +
 mlx/backend/common/default_primitives.cpp |   1 +
 mlx/backend/common/solve.cpp              | 134 ++++++++++++++++++++++
 mlx/backend/metal/primitives.cpp          |   6 +
 mlx/backend/no_cpu/primitives.cpp         |   1 +
 mlx/backend/no_metal/primitives.cpp       |   1 +
 mlx/linalg.cpp                            |  43 +++++++
 mlx/linalg.h                              |   2 +
 mlx/primitives.cpp                        |  11 ++
 mlx/primitives.h                          |  16 +++
 python/src/linalg.cpp                     |  21 ++++
 python/tests/test_linalg.py               |  54 +++++++++
 tests/linalg_tests.cpp                    |  50 +++++++++
 15 files changed, 343 insertions(+)
 create mode 100644 mlx/backend/common/solve.cpp

diff --git a/docs/src/python/linalg.rst b/docs/src/python/linalg.rst
index 227711c227..3c2ab86832 100644
--- a/docs/src/python/linalg.rst
+++ b/docs/src/python/linalg.rst
@@ -16,3 +16,4 @@ Linear Algebra
    cross
    qr
    svd
+   solve
diff --git a/mlx/backend/accelerate/primitives.cpp b/mlx/backend/accelerate/primitives.cpp
index eee93f2ab2..d39acb3bec 100644
--- a/mlx/backend/accelerate/primitives.cpp
+++ b/mlx/backend/accelerate/primitives.cpp
@@ -81,6 +81,7 @@ DEFAULT_MULTI(SVD)
 DEFAULT(Transpose)
 DEFAULT(Inverse)
 DEFAULT(Cholesky)
+DEFAULT_MULTI(Solve)
 
 void Abs::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
diff --git a/mlx/backend/common/CMakeLists.txt b/mlx/backend/common/CMakeLists.txt
index 925f4731c3..55f2780f24 100644
--- a/mlx/backend/common/CMakeLists.txt
+++ b/mlx/backend/common/CMakeLists.txt
@@ -51,6 +51,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/svd.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/inverse.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/cholesky.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/solve.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
   ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp)
diff --git a/mlx/backend/common/default_primitives.cpp b/mlx/backend/common/default_primitives.cpp
index f8932c5f8e..f6968d081c 100644
--- a/mlx/backend/common/default_primitives.cpp
+++ b/mlx/backend/common/default_primitives.cpp
@@ -114,6 +114,7 @@ DEFAULT(Tanh)
 DEFAULT(Transpose)
 DEFAULT(Inverse)
 DEFAULT(Cholesky)
+DEFAULT_MULTI(Solve)
 
 namespace {
diff --git a/mlx/backend/common/solve.cpp b/mlx/backend/common/solve.cpp
new file mode 100644
index 0000000000..be0f190d84
--- /dev/null
+++ b/mlx/backend/common/solve.cpp
@@ -0,0 +1,134 @@
+// Copyright © 2024 Apple Inc.
+
+#include "mlx/allocator.h"
+#include "mlx/backend/common/copy.h"
+#include "mlx/backend/common/lapack_helper.h"
+#include "mlx/primitives.h"
+
+#ifdef ACCELERATE_NEW_LAPACK
+#include <Accelerate/Accelerate.h>
+#else
+#include <lapack.h>
+#endif
+
+#include <cassert>
+
+namespace mlx::core {
+
+namespace {
+
+// Wrapper to account for differences in LAPACK implementations
+// (basically how to pass the 'trans' string to Fortran).
+int sgetrs_wrapper(char trans, int N, int NRHS, int* ipiv, float* a, float* b) {
+  int info;
+
+#ifdef LAPACK_FORTRAN_STRLEN_END
+  sgetrs_(
+      /* trans */ &trans,
+      /* n */ &N,
+      /* nrhs */ &NRHS,
+      /* a */ a,
+      /* lda */ &N,
+      /* ipiv */ ipiv,
+      /* b */ b,
+      /* ldb */ &N,
+      /* info */ &info,
+      /* trans_len = */ static_cast<size_t>(1));
+#else
+  sgetrs_(
+      /* trans */ &trans,
+      /* n */ &N,
+      /* nrhs */ &NRHS,
+      /* a */ a,
+      /* lda */ &N,
+      /* ipiv */ ipiv,
+      /* b */ b,
+      /* ldb */ &N,
+      /* info */ &info);
+#endif
+
+  return info;
+}
+
+} // namespace
+
+void solve_impl(const array& a, const array& b, array& out) {
+  int N = a.shape(-2);
+  // A 1-D right-hand side is treated as a single column.
+  int NRHS = out.ndim() > 1 ? out.shape(-1) : 1;
+  std::vector<int> ipiv(N);
+
+  // Copy b into out and make it col-contiguous.
+  auto flags = out.flags();
+  std::vector<size_t> strides(out.ndim(), 0);
+  std::copy(out.strides().begin(), out.strides().end(), strides.begin());
+  if (out.ndim() > 1) {
+    flags.col_contiguous = true;
+    flags.row_contiguous = false;
+    strides[out.ndim() - 2] = 1;
+    strides[out.ndim() - 1] = N;
+  }
+
+  out.set_data(
+      allocator::malloc_or_wait(out.nbytes()), out.nbytes(), strides, flags);
+  copy_inplace(b, out, CopyType::GeneralGeneral);
+
+  // LAPACK clobbers the input, so we have to make a copy. The copy doesn't need
+  // to be col-contiguous because sgetrs has a transpose parameter (trans='T').
+  array a_cpy(a.shape(), float32, nullptr, {});
+  copy(
+      a,
+      a_cpy,
+      a.flags().row_contiguous ? CopyType::Vector : CopyType::General);
+
+  float* a_ptr = a_cpy.data<float>();
+  float* out_ptr = out.data<float>();
+  int* ipiv_ptr = ipiv.data();
+
+  int info;
+  size_t num_matrices = a.size() / (N * N);
+  for (size_t i = 0; i < num_matrices; i++) {
+    // Compute the LU factorization of A
+    MLX_LAPACK_FUNC(sgetrf)
+    (/* m */ &N,
+     /* n */ &N,
+     /* a */ a_ptr,
+     /* lda */ &N,
+     /* ipiv */ ipiv_ptr,
+     /* info */ &info);
+
+    if (info != 0) {
+      std::stringstream ss;
+      ss << "solve_impl: sgetrf_ failed with code " << info
+         << ((info > 0) ? " because matrix is singular"
+                        : " because argument had an illegal value");
+      throw std::runtime_error(ss.str());
+    }
+
+    static constexpr char trans = 'T';
+    // Solve the system using the LU factors from sgetrf
+    info = sgetrs_wrapper(trans, N, NRHS, ipiv_ptr, a_ptr, out_ptr);
+
+    if (info != 0) {
+      std::stringstream ss;
+      ss << "solve_impl: sgetrs_ failed with code " << info;
+      throw std::runtime_error(ss.str());
+    }
+
+    // Advance pointers to the next matrix
+    a_ptr += N * N;
+    out_ptr += N * NRHS;
+  }
+}
+
+void Solve::eval(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
+  assert(inputs.size() == 2);
+  if (inputs[0].dtype() != float32 || inputs[1].dtype() != float32) {
+    throw std::runtime_error("[Solve::eval] only supports float32.");
+  }
+  solve_impl(inputs[0], inputs[1], outputs[0]);
+}
+
+} // namespace mlx::core
diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp
index 31f2248d72..c5a5615920 100644
--- a/mlx/backend/metal/primitives.cpp
+++ b/mlx/backend/metal/primitives.cpp
@@ -432,4 +432,10 @@ void View::eval_gpu(const std::vector<array>& inputs, array& out) {
   }
 }
 
+void Solve::eval_gpu(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
+  throw std::runtime_error("[Solve::eval_gpu] Metal Solve NYI.");
+}
+
 } // namespace mlx::core
diff --git a/mlx/backend/no_cpu/primitives.cpp b/mlx/backend/no_cpu/primitives.cpp
index ff60e4d22a..eb1a3ace11 100644
--- a/mlx/backend/no_cpu/primitives.cpp
+++ b/mlx/backend/no_cpu/primitives.cpp
@@ -108,5 +108,6 @@ NO_CPU(Tanh)
 NO_CPU(Transpose)
 NO_CPU(Inverse)
 NO_CPU(View)
+NO_CPU_MULTI(Solve)
 
 } // namespace mlx::core
diff --git a/mlx/backend/no_metal/primitives.cpp b/mlx/backend/no_metal/primitives.cpp
index 544a2c6f2b..0ea0eee804 100644
--- a/mlx/backend/no_metal/primitives.cpp
+++ b/mlx/backend/no_metal/primitives.cpp
@@ -111,6 +111,7 @@ NO_GPU(Transpose)
 NO_GPU(Inverse)
 NO_GPU(Cholesky)
 NO_GPU(View)
+NO_GPU_MULTI(Solve)
 
 namespace fast {
 NO_GPU_MULTI(LayerNorm)
diff --git a/mlx/linalg.cpp b/mlx/linalg.cpp
index a64f98aa8f..e14b148e03 100644
--- a/mlx/linalg.cpp
+++ b/mlx/linalg.cpp
@@ -454,4 +454,47 @@ array cross(
   return concatenate(outputs, axis, s);
 }
 
+array solve(const array& a, const array& b, StreamOrDevice s /* = {} */) {
+  if (a.dtype() != float32 || b.dtype() != float32) {
+    std::ostringstream msg;
+    msg << "[linalg::solve] Input arrays must have type float32. Received arrays "
+        << "with types " << a.dtype() << " and " << b.dtype() << ".";
+    throw std::invalid_argument(msg.str());
+  }
+
+  if (a.ndim() < 2) {
+    std::ostringstream msg;
+    msg << "[linalg::solve] First input must have >= 2 dimensions. Received array "
+        << "with " << a.ndim() << " dimensions.";
+    throw std::invalid_argument(msg.str());
+  }
+
+  if (b.ndim() < 1) {
+    std::ostringstream msg;
+    msg << "[linalg::solve] Second input must have >= 1 dimension. Received array "
+        << "with " << b.ndim() << " dimensions.";
+    throw std::invalid_argument(msg.str());
+  }
+
+  if (a.shape(-1) != a.shape(-2)) {
+    std::ostringstream msg;
+    msg << "[linalg::solve] First input must be a square matrix. Received array "
+        << "with shape " << a.shape() << ".";
+    throw std::invalid_argument(msg.str());
+  }
+
+  if (a.shape(-1) != b.shape(b.ndim() >= 2 ? -2 : -1)) {
+    std::ostringstream msg;
+    msg << "[linalg::solve] Last dimension of first input with shape "
+        << a.shape() << " must match second to last dimension of"
+        << " second input with shape " << b.shape() << ".";
+    throw std::invalid_argument(msg.str());
+  }
+
+  auto out_type = promote_types(a.dtype(), b.dtype());
+
+  return array(
+      b.shape(), out_type, std::make_shared<Solve>(to_stream(s)), {a, b});
+}
+
 } // namespace mlx::core::linalg
diff --git a/mlx/linalg.h b/mlx/linalg.h
index acfcc1a415..a2b4b811a8 100644
--- a/mlx/linalg.h
+++ b/mlx/linalg.h
@@ -74,6 +74,8 @@ array pinv(const array& a, StreamOrDevice s = {});
 
 array cholesky_inv(const array& a, bool upper = false, StreamOrDevice s = {});
 
+array solve(const array& a, const array& b, StreamOrDevice s = {});
+
 /**
  * Compute the cross product of two arrays along the given axis.
  */
diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp
index c28a945a39..716ec20891 100644
--- a/mlx/primitives.cpp
+++ b/mlx/primitives.cpp
@@ -4110,6 +4110,17 @@ std::pair<std::vector<array>, std::vector<int>> SVD::vmap(
   return {{linalg::svd(a, stream())}, {ax, ax, ax}};
 }
 
+std::pair<std::vector<array>, std::vector<int>> Solve::vmap(
+    const std::vector<array>& inputs,
+    const std::vector<int>& axes) {
+  auto maybe_move_ax = [this](auto& arr, auto ax) {
+    return ax > 0 ? moveaxis(arr, ax, 0, stream()) : arr;
+  };
+  auto a = maybe_move_ax(inputs[0], axes[0]);
+  auto b = maybe_move_ax(inputs[1], axes[1]);
+  return {{linalg::solve(a, b, stream())}, {0}};
+}
+
 std::pair<std::vector<array>, std::vector<int>> Inverse::vmap(
     const std::vector<array>& inputs,
     const std::vector<int>& axes) {
diff --git a/mlx/primitives.h b/mlx/primitives.h
index 810eb50963..fe504d040d 100644
--- a/mlx/primitives.h
+++ b/mlx/primitives.h
@@ -2168,4 +2168,20 @@ class Cholesky : public UnaryPrimitive {
   bool upper_;
 };
 
+class Solve : public Primitive {
+ public:
+  explicit Solve(Stream stream) : Primitive(stream) {}
+
+  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
+      override;
+  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
+      override;
+
+  DEFINE_VMAP()
+  DEFINE_PRINT(Solve)
+
+ private:
+  void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
+};
+
 } // namespace mlx::core
diff --git a/python/src/linalg.cpp b/python/src/linalg.cpp
index 65dd8d0e4e..60e838955d 100644
--- a/python/src/linalg.cpp
+++ b/python/src/linalg.cpp
@@ -405,4 +405,25 @@ void init_linalg(nb::module_& parent_module) {
       Returns:
         array: The cross product of ``a`` and ``b`` along the specified axis.
       )pbdoc");
+  m.def(
+      "solve",
+      &solve,
+      "a"_a,
+      "b"_a,
+      nb::kw_only(),
+      "stream"_a = nb::none(),
+      nb::sig(
+          "def solve(a: array, b: array, *, stream: Union[None, Stream, Device] = None) -> array"),
+      R"pbdoc(
+        Compute the solution to a square system of linear equations AX = B.
+
+        Args:
+            a (array): The square coefficient matrix ``A``.
+            b (array): The right-hand side ``B``.
+            stream (Stream, optional): Stream or device. Defaults to ``None``
+              in which case the default stream of the default device is used.
+
+        Returns:
+            array: The unique solution to the system AX = B.
+      )pbdoc");
 }
diff --git a/python/tests/test_linalg.py b/python/tests/test_linalg.py
index 6051beef7b..5a5b2e111a 100644
--- a/python/tests/test_linalg.py
+++ b/python/tests/test_linalg.py
@@ -268,6 +268,60 @@ def test_cross_product(self):
         with self.assertRaises(ValueError):
             mx.linalg.cross(a, b)
 
+    def test_solve(self):
+        mx.random.seed(7)
+
+        # Test 3x3 matrix with 1D rhs
+        a = mx.array([[3.0, 1.0, 2.0], [1.0, 8.0, 6.0], [9.0, 2.0, 5.0]])
+        b = mx.array([11.0, 35.0, 28.0])
+
+        result = mx.linalg.solve(a, b, stream=mx.cpu)
+        expected = np.linalg.solve(a, b)
+        self.assertTrue(np.allclose(result, expected))
+
+        # Test symmetric positive-definite matrix
+        N = 5
+        a = mx.random.uniform(shape=(N, N))
+        a = mx.matmul(a, a.T) + N * mx.eye(N)
+        b = mx.random.uniform(shape=(N, 1))
+
+        result = mx.linalg.solve(a, b, stream=mx.cpu)
+        expected = np.linalg.solve(a, b)
+        self.assertTrue(np.allclose(result, expected, atol=1e-5))
+
+        # Test batch dimension
+        a = mx.random.uniform(shape=(5, 5, 4, 4))
+        b = mx.random.uniform(shape=(5, 5, 4, 1))
+
+        result = mx.linalg.solve(a, b, stream=mx.cpu)
+        expected = np.linalg.solve(a, b)
+        self.assertTrue(np.allclose(result, expected, atol=1e-5))
+
+        # Test large matrix
+        N = 1000
+        a = mx.random.uniform(shape=(N, N))
+        b = mx.random.uniform(shape=(N, 1))
+
+        result = mx.linalg.solve(a, b, stream=mx.cpu)
+        expected = np.linalg.solve(a, b)
+        self.assertTrue(np.allclose(result, expected, atol=1e-2))
+
+        # Test multi-column rhs
+        a = mx.random.uniform(shape=(5, 5))
+        b = mx.random.uniform(shape=(5, 8))
+
+        result = mx.linalg.solve(a, b, stream=mx.cpu)
+        expected = np.linalg.solve(a, b)
+        self.assertTrue(np.allclose(result, expected, atol=1e-5))
+
+        # Test batched multi-column rhs
+        a = mx.concatenate([a, a, a, a, a, a]).reshape((3, 2, 5, 5))
+        b = mx.concatenate([b, b, b, b, b, b]).reshape((3, 2, 5, 8))
+
+        result = mx.linalg.solve(a, b, stream=mx.cpu)
+        expected = np.linalg.solve(a, b)
+        self.assertTrue(np.allclose(result, expected, atol=1e-5))
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/tests/linalg_tests.cpp b/tests/linalg_tests.cpp
index e9e1965837..57d7d09c72 100644
--- a/tests/linalg_tests.cpp
+++ b/tests/linalg_tests.cpp
@@ -435,3 +435,53 @@ TEST_CASE("test cross product") {
   result = cross(a, b);
   CHECK(allclose(result, expected).item<bool>());
 }
+
+TEST_CASE("test solve") {
+  // 0D and 1D throw
+  CHECK_THROWS(linalg::solve(array(0.), array(0.), Device::cpu));
+  CHECK_THROWS(linalg::solve(array({0.}), array({0.}), Device::cpu));
+
+  // Unsupported types throw
+  CHECK_THROWS(
+      linalg::solve(array({0, 1, 1, 2}, {2, 2}), array({1, 3}), Device::cpu));
+
+  // Non-square throws
+  array a = reshape(arange(6), {3, 2});
+  array b = reshape(arange(3), {3, 1});
+  CHECK_THROWS(linalg::solve(a, b, Device::cpu));
+
+  // Test 2x2 matrix with 1D rhs
+  a = array({2., 1., 1., 3.}, {2, 2});
+  b = array({8., 13.}, {2});
+
+  array result = linalg::solve(a, b, Device::cpu);
+  CHECK(allclose(matmul(a, result), b).item<bool>());
+
+  // Test 3x3 matrix
+  a = array({1., 2., 3., 4., 5., 6., 7., 8., 10.}, {3, 3});
+  b = array({6., 15., 25.}, {3, 1});
+
+  result = linalg::solve(a, b, Device::cpu);
+  CHECK(allclose(matmul(a, result), b).item<bool>());
+
+  // Test batch dimension
+  a = reshape(concatenate({a, a, a, a, a}), {5, 3, 3});
+  b = reshape(concatenate({b, b, b, b, b}), {5, 3, 1});
+
+  result = linalg::solve(a, b, Device::cpu);
+  CHECK(allclose(matmul(a, result), b).item<bool>());
+
+  // Test multi-column rhs
+  a = array({2., 1., 1., 1., 3., 2., 1., 0., 0.}, {3, 3});
+  b = array({4., 2., 5., 3., 6., 1.}, {3, 2});
+
+  result = linalg::solve(a, b, Device::cpu);
+  CHECK(allclose(matmul(a, result), b).item<bool>());
+
+  // Test batch multi-column rhs
+  a = reshape(concatenate({a, a, a, a, a}), {5, 3, 3});
+  b = reshape(concatenate({b, b, b, b, b}), {5, 3, 2});
+
+  result = linalg::solve(a, b, Device::cpu);
+  CHECK(allclose(matmul(a, result), b).item<bool>());
+}
\ No newline at end of file
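
A minimal usage sketch of the new API (not part of the patch), using only calls
that already appear in the diff above. The solve must run on the CPU stream,
since Solve::eval_gpu throws until a Metal kernel lands:

    import mlx.core as mx

    # Solve AX = B for a 3x3 system with a single right-hand-side column.
    a = mx.array([[3.0, 1.0, 2.0], [1.0, 8.0, 6.0], [9.0, 2.0, 5.0]])
    b = mx.array([[11.0], [35.0], [28.0]])

    # Request the CPU stream explicitly; there is no GPU kernel yet.
    x = mx.linalg.solve(a, b, stream=mx.cpu)

    # Check the residual: A @ X should reproduce B.
    assert mx.allclose(mx.matmul(a, x), b)

Design note: solve_impl copies A row-contiguous and passes trans='T' to sgetrs.
LAPACK reads the buffer column-major, i.e. as A^T, so factoring that buffer and
then solving the transposed system (A^T)^T X = AX = B avoids an explicit
transpose copy of A; only B is rearranged into column-major order.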