diff --git a/python/src/array.cpp b/python/src/array.cpp index 017fb6e91..9d871af33 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -22,56 +22,56 @@ #include "mlx/transforms.h" #include "mlx/utils.h" +namespace mx = mlx::core; namespace nb = nanobind; using namespace nb::literals; -using namespace mlx::core; class ArrayAt { public: - ArrayAt(array x) : x_(std::move(x)) {} + ArrayAt(mx::array x) : x_(std::move(x)) {} ArrayAt& set_indices(nb::object indices) { indices_ = indices; return *this; } - array add(const ScalarOrArray& v) { + mx::array add(const ScalarOrArray& v) { return mlx_add_item(x_, indices_, v); } - array subtract(const ScalarOrArray& v) { + mx::array subtract(const ScalarOrArray& v) { return mlx_subtract_item(x_, indices_, v); } - array multiply(const ScalarOrArray& v) { + mx::array multiply(const ScalarOrArray& v) { return mlx_multiply_item(x_, indices_, v); } - array divide(const ScalarOrArray& v) { + mx::array divide(const ScalarOrArray& v) { return mlx_divide_item(x_, indices_, v); } - array maximum(const ScalarOrArray& v) { + mx::array maximum(const ScalarOrArray& v) { return mlx_maximum_item(x_, indices_, v); } - array minimum(const ScalarOrArray& v) { + mx::array minimum(const ScalarOrArray& v) { return mlx_minimum_item(x_, indices_, v); } private: - array x_; + mx::array x_; nb::object indices_; }; class ArrayPythonIterator { public: - ArrayPythonIterator(array x) : idx_(0), x_(std::move(x)) { + ArrayPythonIterator(mx::array x) : idx_(0), x_(std::move(x)) { if (x_.shape(0) > 0 && x_.shape(0) < 10) { - splits_ = split(x_, x_.shape(0)); + splits_ = mx::split(x_, x_.shape(0)); } } - array next() { + mx::array next() { if (idx_ >= x_.shape(0)) { throw nb::stop_iteration(); } if (idx_ >= 0 && idx_ < splits_.size()) { - return squeeze(splits_[idx_++], 0); + return mx::squeeze(splits_[idx_++], 0); } return *(x_.begin() + idx_++); @@ -79,16 +79,16 @@ class ArrayPythonIterator { private: int idx_; - array x_; - std::vector splits_; + mx::array x_; + std::vector splits_; }; void init_array(nb::module_& m) { // Set Python print formatting options - get_global_formatter().capitalize_bool = true; + mx::get_global_formatter().capitalize_bool = true; // Types - nb::class_( + nb::class_( m, "Dtype", R"pbdoc( @@ -98,10 +98,10 @@ void init_array(nb::module_& m) { on available data types. )pbdoc") .def_prop_ro( - "size", &Dtype::size, R"pbdoc(Size of the type in bytes.)pbdoc") + "size", &mx::Dtype::size, R"pbdoc(Size of the type in bytes.)pbdoc") .def( "__repr__", - [](const Dtype& t) { + [](const mx::Dtype& t) { std::ostringstream os; os << "mlx.core."; os << t; @@ -109,27 +109,28 @@ void init_array(nb::module_& m) { }) .def( "__eq__", - [](const Dtype& t, const nb::object& other) { - return nb::isinstance(other) && t == nb::cast(other); + [](const mx::Dtype& t, const nb::object& other) { + return nb::isinstance(other) && + t == nb::cast(other); }) - .def("__hash__", [](const Dtype& t) { + .def("__hash__", [](const mx::Dtype& t) { return static_cast(t.val()); }); - m.attr("bool_") = nb::cast(bool_); - m.attr("uint8") = nb::cast(uint8); - m.attr("uint16") = nb::cast(uint16); - m.attr("uint32") = nb::cast(uint32); - m.attr("uint64") = nb::cast(uint64); - m.attr("int8") = nb::cast(int8); - m.attr("int16") = nb::cast(int16); - m.attr("int32") = nb::cast(int32); - m.attr("int64") = nb::cast(int64); - m.attr("float16") = nb::cast(float16); - m.attr("float32") = nb::cast(float32); - m.attr("bfloat16") = nb::cast(bfloat16); - m.attr("complex64") = nb::cast(complex64); - nb::enum_( + m.attr("bool_") = nb::cast(mx::bool_); + m.attr("uint8") = nb::cast(mx::uint8); + m.attr("uint16") = nb::cast(mx::uint16); + m.attr("uint32") = nb::cast(mx::uint32); + m.attr("uint64") = nb::cast(mx::uint64); + m.attr("int8") = nb::cast(mx::int8); + m.attr("int16") = nb::cast(mx::int16); + m.attr("int32") = nb::cast(mx::int32); + m.attr("int64") = nb::cast(mx::int64); + m.attr("float16") = nb::cast(mx::float16); + m.attr("float32") = nb::cast(mx::float32); + m.attr("bfloat16") = nb::cast(mx::bfloat16); + m.attr("complex64") = nb::cast(mx::complex64); + nb::enum_( m, "DtypeCategory", R"pbdoc( @@ -169,14 +170,14 @@ void init_array(nb::module_& m) { See also :func:`~mlx.core.issubdtype`. )pbdoc") - .value("complexfloating", complexfloating) - .value("floating", floating) - .value("inexact", inexact) - .value("signedinteger", signedinteger) - .value("unsignedinteger", unsignedinteger) - .value("integer", integer) - .value("number", number) - .value("generic", generic) + .value("complexfloating", mx::complexfloating) + .value("floating", mx::floating) + .value("inexact", mx::inexact) + .value("signedinteger", mx::signedinteger) + .value("unsignedinteger", mx::unsignedinteger) + .value("integer", mx::integer) + .value("number", mx::number) + .value("generic", mx::generic) .export_values(); nb::class_( m, @@ -207,7 +208,7 @@ void init_array(nb::module_& m) { {Py_bf_releasebuffer, (void*)releasebuffer}, {0, nullptr}}; - nb::class_( + nb::class_( m, "array", R"pbdoc(An N-dimensional array object.)pbdoc", @@ -215,27 +216,30 @@ void init_array(nb::module_& m) { nb::is_weak_referenceable()) .def( "__init__", - [](array* aptr, ArrayInitType v, std::optional t) { - new (aptr) array(create_array(v, t)); + [](mx::array* aptr, ArrayInitType v, std::optional t) { + new (aptr) mx::array(create_array(v, t)); }, "val"_a, "dtype"_a = nb::none(), nb::sig( "def __init__(self: array, val: Union[scalar, list, tuple, numpy.ndarray, array], dtype: Optional[Dtype] = None)")) .def_prop_ro( - "size", &array::size, R"pbdoc(Number of elements in the array.)pbdoc") - .def_prop_ro("ndim", &array::ndim, R"pbdoc(The array's dimension.)pbdoc") + "size", + &mx::array::size, + R"pbdoc(Number of elements in the array.)pbdoc") + .def_prop_ro( + "ndim", &mx::array::ndim, R"pbdoc(The array's dimension.)pbdoc") .def_prop_ro( "itemsize", - &array::itemsize, + &mx::array::itemsize, R"pbdoc(The size of the array's datatype in bytes.)pbdoc") .def_prop_ro( "nbytes", - &array::nbytes, + &mx::array::nbytes, R"pbdoc(The number of bytes in the array.)pbdoc") .def_prop_ro( "shape", - [](const array& a) { return nb::tuple(nb::cast(a.shape())); }, + [](const mx::array& a) { return nb::tuple(nb::cast(a.shape())); }, R"pbdoc( The shape of the array as a Python tuple. @@ -244,7 +248,7 @@ void init_array(nb::module_& m) { )pbdoc") .def_prop_ro( "dtype", - &array::dtype, + &mx::array::dtype, R"pbdoc( The array's :class:`Dtype`. )pbdoc") @@ -276,7 +280,7 @@ void init_array(nb::module_& m) { )pbdoc") .def( "astype", - &astype, + &mx::astype, "dtype"_a, "stream"_a = nb::none(), R"pbdoc( @@ -291,7 +295,8 @@ void init_array(nb::module_& m) { )pbdoc") .def( "__array_namespace__", - [](const array& a, const std::optional& api_version) { + [](const mx::array& a, + const std::optional& api_version) { if (api_version) { throw std::invalid_argument( "Explicitly specifying api_version is not yet implemented."); @@ -316,7 +321,7 @@ void init_array(nb::module_& m) { .def("__setitem__", mlx_set_item, nb::arg().none(), nb::arg()) .def_prop_ro( "at", - [](const array& a) { return ArrayAt(a); }, + [](const mx::array& a) { return ArrayAt(a); }, R"pbdoc( Used to apply updates at the given indices. @@ -358,25 +363,26 @@ void init_array(nb::module_& m) { )pbdoc") .def( "__len__", - [](const array& a) { + [](const mx::array& a) { if (a.ndim() == 0) { throw nb::type_error("len() 0-dimensional array."); } return a.shape(0); }) - .def("__iter__", [](const array& a) { return ArrayPythonIterator(a); }) + .def( + "__iter__", [](const mx::array& a) { return ArrayPythonIterator(a); }) .def("__getstate__", &mlx_to_np_array) .def( "__setstate__", - [](array& arr, + [](mx::array& arr, const nb::ndarray& state) { - new (&arr) array(nd_array_to_mlx(state, std::nullopt)); + new (&arr) mx::array(nd_array_to_mlx(state, std::nullopt)); }) - .def("__dlpack__", [](const array& a) { return mlx_to_dlpack(a); }) + .def("__dlpack__", [](const mx::array& a) { return mlx_to_dlpack(a); }) .def( "__dlpack_device__", - [](const array& a) { - if (metal::is_available()) { + [](const mx::array& a) { + if (mx::metal::is_available()) { // Metal device is available return nb::make_tuple(8, 0); } else { @@ -384,115 +390,115 @@ void init_array(nb::module_& m) { return nb::make_tuple(1, 0); } }) - .def("__copy__", [](const array& self) { return array(self); }) + .def("__copy__", [](const mx::array& self) { return mx::array(self); }) .def( "__deepcopy__", - [](const array& self, nb::dict) { return array(self); }, + [](const mx::array& self, nb::dict) { return mx::array(self); }, "memo"_a) .def( "__add__", - [](const array& a, const ScalarOrArray v) { + [](const mx::array& a, const ScalarOrArray v) { if (!is_comparable_with_array(v)) { throw_invalid_operation("addition", v); } auto b = to_array(v, a.dtype()); - return add(a, b); + return mx::add(a, b); }, "other"_a) .def( "__iadd__", - [](array& a, const ScalarOrArray v) -> array& { + [](mx::array& a, const ScalarOrArray v) -> mx::array& { if (!is_comparable_with_array(v)) { throw_invalid_operation("inplace addition", v); } - a.overwrite_descriptor(add(a, to_array(v, a.dtype()))); + a.overwrite_descriptor(mx::add(a, to_array(v, a.dtype()))); return a; }, "other"_a, nb::rv_policy::none) .def( "__radd__", - [](const array& a, const ScalarOrArray v) { + [](const mx::array& a, const ScalarOrArray v) { if (!is_comparable_with_array(v)) { throw_invalid_operation("addition", v); } - return add(a, to_array(v, a.dtype())); + return mx::add(a, to_array(v, a.dtype())); }, "other"_a) .def( "__sub__", - [](const array& a, const ScalarOrArray v) { + [](const mx::array& a, const ScalarOrArray v) { if (!is_comparable_with_array(v)) { throw_invalid_operation("subtraction", v); } - return subtract(a, to_array(v, a.dtype())); + return mx::subtract(a, to_array(v, a.dtype())); }, "other"_a) .def( "__isub__", - [](array& a, const ScalarOrArray v) -> array& { + [](mx::array& a, const ScalarOrArray v) -> mx::array& { if (!is_comparable_with_array(v)) { throw_invalid_operation("inplace subtraction", v); } - a.overwrite_descriptor(subtract(a, to_array(v, a.dtype()))); + a.overwrite_descriptor(mx::subtract(a, to_array(v, a.dtype()))); return a; }, "other"_a, nb::rv_policy::none) .def( "__rsub__", - [](const array& a, const ScalarOrArray v) { + [](const mx::array& a, const ScalarOrArray v) { if (!is_comparable_with_array(v)) { throw_invalid_operation("subtraction", v); } - return subtract(to_array(v, a.dtype()), a); + return mx::subtract(to_array(v, a.dtype()), a); }, "other"_a) .def( "__mul__", - [](const array& a, const ScalarOrArray v) { + [](const mx::array& a, const ScalarOrArray v) { if (!is_comparable_with_array(v)) { throw_invalid_operation("multiplication", v); } - return multiply(a, to_array(v, a.dtype())); + return mx::multiply(a, to_array(v, a.dtype())); }, "other"_a) .def( "__imul__", - [](array& a, const ScalarOrArray v) -> array& { + [](mx::array& a, const ScalarOrArray v) -> mx::array& { if (!is_comparable_with_array(v)) { throw_invalid_operation("inplace multiplication", v); } - a.overwrite_descriptor(multiply(a, to_array(v, a.dtype()))); + a.overwrite_descriptor(mx::multiply(a, to_array(v, a.dtype()))); return a; }, "other"_a, nb::rv_policy::none) .def( "__rmul__", - [](const array& a, const ScalarOrArray v) { + [](const mx::array& a, const ScalarOrArray v) { if (!is_comparable_with_array(v)) { throw_invalid_operation("multiplication", v); } - return multiply(a, to_array(v, a.dtype())); + return mx::multiply(a, to_array(v, a.dtype())); }, "other"_a) .def( "__truediv__", - [](const array& a, const ScalarOrArray v) { + [](const mx::array& a, const ScalarOrArray v) { if (!is_comparable_with_array(v)) { throw_invalid_operation("division", v); } - return divide(a, to_array(v, a.dtype())); + return mx::divide(a, to_array(v, a.dtype())); }, "other"_a) .def( "__itruediv__", - [](array& a, const ScalarOrArray v) -> array& { + [](mx::array& a, const ScalarOrArray v) -> mx::array& { if (!is_comparable_with_array(v)) { throw_invalid_operation("inplace division", v); } - if (!issubdtype(a.dtype(), inexact)) { + if (!mx::issubdtype(a.dtype(), mx::inexact)) { throw std::invalid_argument( "In place division cannot cast to non-floating point type."); } @@ -503,151 +509,151 @@ void init_array(nb::module_& m) { nb::rv_policy::none) .def( "__rtruediv__", - [](const array& a, const ScalarOrArray v) { + [](const mx::array& a, const ScalarOrArray v) { if (!is_comparable_with_array(v)) { throw_invalid_operation("division", v); } - return divide(to_array(v, a.dtype()), a); + return mx::divide(to_array(v, a.dtype()), a); }, "other"_a) .def( "__div__", - [](const array& a, const ScalarOrArray v) { + [](const mx::array& a, const ScalarOrArray v) { if (!is_comparable_with_array(v)) { throw_invalid_operation("division", v); } - return divide(a, to_array(v, a.dtype())); + return mx::divide(a, to_array(v, a.dtype())); }, "other"_a) .def( "__rdiv__", - [](const array& a, const ScalarOrArray v) { + [](const mx::array& a, const ScalarOrArray v) { if (!is_comparable_with_array(v)) { throw_invalid_operation("division", v); } - return divide(to_array(v, a.dtype()), a); + return mx::divide(to_array(v, a.dtype()), a); }, "other"_a) .def( "__floordiv__", - [](const array& a, const ScalarOrArray v) { + [](const mx::array& a, const ScalarOrArray v) { if (!is_comparable_with_array(v)) { throw_invalid_operation("floor division", v); } - return floor_divide(a, to_array(v, a.dtype())); + return mx::floor_divide(a, to_array(v, a.dtype())); }, "other"_a) .def( "__ifloordiv__", - [](array& a, const ScalarOrArray v) -> array& { + [](mx::array& a, const ScalarOrArray v) -> mx::array& { if (!is_comparable_with_array(v)) { throw_invalid_operation("inplace floor division", v); } - a.overwrite_descriptor(floor_divide(a, to_array(v, a.dtype()))); + a.overwrite_descriptor(mx::floor_divide(a, to_array(v, a.dtype()))); return a; }, "other"_a, nb::rv_policy::none) .def( "__rfloordiv__", - [](const array& a, const ScalarOrArray v) { + [](const mx::array& a, const ScalarOrArray v) { if (!is_comparable_with_array(v)) { throw_invalid_operation("floor division", v); } auto b = to_array(v, a.dtype()); - return floor_divide(b, a); + return mx::floor_divide(b, a); }, "other"_a) .def( "__mod__", - [](const array& a, const ScalarOrArray v) { + [](const mx::array& a, const ScalarOrArray v) { if (!is_comparable_with_array(v)) { throw_invalid_operation("modulus", v); } - return remainder(a, to_array(v, a.dtype())); + return mx::remainder(a, to_array(v, a.dtype())); }, "other"_a) .def( "__imod__", - [](array& a, const ScalarOrArray v) -> array& { + [](mx::array& a, const ScalarOrArray v) -> mx::array& { if (!is_comparable_with_array(v)) { throw_invalid_operation("inplace modulus", v); } - a.overwrite_descriptor(remainder(a, to_array(v, a.dtype()))); + a.overwrite_descriptor(mx::remainder(a, to_array(v, a.dtype()))); return a; }, "other"_a, nb::rv_policy::none) .def( "__rmod__", - [](const array& a, const ScalarOrArray v) { + [](const mx::array& a, const ScalarOrArray v) { if (!is_comparable_with_array(v)) { throw_invalid_operation("modulus", v); } - return remainder(to_array(v, a.dtype()), a); + return mx::remainder(to_array(v, a.dtype()), a); }, "other"_a) .def( "__eq__", - [](const array& a, - const ScalarOrArray& v) -> std::variant { + [](const mx::array& a, + const ScalarOrArray& v) -> std::variant { if (!is_comparable_with_array(v)) { return false; } - return equal(a, to_array(v, a.dtype())); + return mx::equal(a, to_array(v, a.dtype())); }, "other"_a) .def( "__lt__", - [](const array& a, const ScalarOrArray v) -> array { + [](const mx::array& a, const ScalarOrArray v) -> mx::array { if (!is_comparable_with_array(v)) { throw_invalid_operation("less than", v); } - return less(a, to_array(v, a.dtype())); + return mx::less(a, to_array(v, a.dtype())); }, "other"_a) .def( "__le__", - [](const array& a, const ScalarOrArray v) -> array { + [](const mx::array& a, const ScalarOrArray v) -> mx::array { if (!is_comparable_with_array(v)) { throw_invalid_operation("less than or equal", v); } - return less_equal(a, to_array(v, a.dtype())); + return mx::less_equal(a, to_array(v, a.dtype())); }, "other"_a) .def( "__gt__", - [](const array& a, const ScalarOrArray v) -> array { + [](const mx::array& a, const ScalarOrArray v) -> mx::array { if (!is_comparable_with_array(v)) { throw_invalid_operation("greater than", v); } - return greater(a, to_array(v, a.dtype())); + return mx::greater(a, to_array(v, a.dtype())); }, "other"_a) .def( "__ge__", - [](const array& a, const ScalarOrArray v) -> array { + [](const mx::array& a, const ScalarOrArray v) -> mx::array { if (!is_comparable_with_array(v)) { throw_invalid_operation("greater than or equal", v); } - return greater_equal(a, to_array(v, a.dtype())); + return mx::greater_equal(a, to_array(v, a.dtype())); }, "other"_a) .def( "__ne__", - [](const array& a, - const ScalarOrArray v) -> std::variant { + [](const mx::array& a, + const ScalarOrArray v) -> std::variant { if (!is_comparable_with_array(v)) { return true; } - return not_equal(a, to_array(v, a.dtype())); + return mx::not_equal(a, to_array(v, a.dtype())); }, "other"_a) - .def("__neg__", [](const array& a) { return -a; }) - .def("__bool__", [](array& a) { return nb::bool_(to_scalar(a)); }) + .def("__neg__", [](const mx::array& a) { return -a; }) + .def("__bool__", [](mx::array& a) { return nb::bool_(to_scalar(a)); }) .def( "__repr__", - [](array& a) { + [](mx::array& a) { nb::gil_scoped_release nogil; std::ostringstream os; os << a; @@ -655,191 +661,193 @@ void init_array(nb::module_& m) { }) .def( "__matmul__", - [](const array& a, array& other) { return matmul(a, other); }, + [](const mx::array& a, mx::array& other) { + return mx::matmul(a, other); + }, "other"_a) .def( "__imatmul__", - [](array& a, array& other) -> array& { - a.overwrite_descriptor(matmul(a, other)); + [](mx::array& a, mx::array& other) -> mx::array& { + a.overwrite_descriptor(mx::matmul(a, other)); return a; }, "other"_a, nb::rv_policy::none) .def( "__pow__", - [](const array& a, const ScalarOrArray v) { + [](const mx::array& a, const ScalarOrArray v) { if (!is_comparable_with_array(v)) { throw_invalid_operation("power", v); } - return power(a, to_array(v, a.dtype())); + return mx::power(a, to_array(v, a.dtype())); }, "other"_a) .def( "__rpow__", - [](const array& a, const ScalarOrArray v) { + [](const mx::array& a, const ScalarOrArray v) { if (!is_comparable_with_array(v)) { throw_invalid_operation("power", v); } - return power(to_array(v, a.dtype()), a); + return mx::power(to_array(v, a.dtype()), a); }, "other"_a) .def( "__ipow__", - [](array& a, const ScalarOrArray v) -> array& { + [](mx::array& a, const ScalarOrArray v) -> mx::array& { if (!is_comparable_with_array(v)) { throw_invalid_operation("inplace power", v); } - a.overwrite_descriptor(power(a, to_array(v, a.dtype()))); + a.overwrite_descriptor(mx::power(a, to_array(v, a.dtype()))); return a; }, "other"_a, nb::rv_policy::none) .def( "__invert__", - [](const array& a) { - if (issubdtype(a.dtype(), inexact)) { + [](const mx::array& a) { + if (mx::issubdtype(a.dtype(), mx::inexact)) { throw std::invalid_argument( "Floating point types not allowed with or bitwise inversion."); } - if (a.dtype() != bool_) { + if (a.dtype() != mx::bool_) { throw std::invalid_argument( "Bitwise inversion not yet supported for integer types."); } - return logical_not(a); + return mx::logical_not(a); }) .def( "__and__", - [](const array& a, const ScalarOrArray v) { + [](const mx::array& a, const ScalarOrArray v) { if (!is_comparable_with_array(v)) { throw_invalid_operation("bitwise and", v); } auto b = to_array(v, a.dtype()); - if (issubdtype(a.dtype(), inexact) || - issubdtype(b.dtype(), inexact)) { + if (mx::issubdtype(a.dtype(), mx::inexact) || + mx::issubdtype(b.dtype(), mx::inexact)) { throw std::invalid_argument( "Floating point types not allowed with bitwise and."); } - return bitwise_and(a, b); + return mx::bitwise_and(a, b); }, "other"_a) .def( "__iand__", - [](array& a, const ScalarOrArray v) -> array& { + [](mx::array& a, const ScalarOrArray v) -> mx::array& { if (!is_comparable_with_array(v)) { throw_invalid_operation("inplace bitwise and", v); } auto b = to_array(v, a.dtype()); - if (issubdtype(a.dtype(), inexact) || - issubdtype(b.dtype(), inexact)) { + if (mx::issubdtype(a.dtype(), mx::inexact) || + mx::issubdtype(b.dtype(), mx::inexact)) { throw std::invalid_argument( "Floating point types not allowed with bitwise and."); } - a.overwrite_descriptor(bitwise_and(a, b)); + a.overwrite_descriptor(mx::bitwise_and(a, b)); return a; }, "other"_a, nb::rv_policy::none) .def( "__or__", - [](const array& a, const ScalarOrArray v) { + [](const mx::array& a, const ScalarOrArray v) { if (!is_comparable_with_array(v)) { throw_invalid_operation("bitwise or", v); } auto b = to_array(v, a.dtype()); - if (issubdtype(a.dtype(), inexact) || - issubdtype(b.dtype(), inexact)) { + if (mx::issubdtype(a.dtype(), mx::inexact) || + mx::issubdtype(b.dtype(), mx::inexact)) { throw std::invalid_argument( "Floating point types not allowed with or bitwise or."); } - return bitwise_or(a, b); + return mx::bitwise_or(a, b); }, "other"_a) .def( "__ior__", - [](array& a, const ScalarOrArray v) -> array& { + [](mx::array& a, const ScalarOrArray v) -> mx::array& { if (!is_comparable_with_array(v)) { throw_invalid_operation("inplace bitwise or", v); } auto b = to_array(v, a.dtype()); - if (issubdtype(a.dtype(), inexact) || - issubdtype(b.dtype(), inexact)) { + if (mx::issubdtype(a.dtype(), mx::inexact) || + mx::issubdtype(b.dtype(), mx::inexact)) { throw std::invalid_argument( "Floating point types not allowed with or bitwise or."); } - a.overwrite_descriptor(bitwise_or(a, b)); + a.overwrite_descriptor(mx::bitwise_or(a, b)); return a; }, "other"_a, nb::rv_policy::none) .def( "__lshift__", - [](const array& a, const ScalarOrArray v) { + [](const mx::array& a, const ScalarOrArray v) { if (!is_comparable_with_array(v)) { throw_invalid_operation("left shift", v); } auto b = to_array(v, a.dtype()); - if (issubdtype(a.dtype(), inexact) || - issubdtype(b.dtype(), inexact)) { + if (mx::issubdtype(a.dtype(), mx::inexact) || + mx::issubdtype(b.dtype(), mx::inexact)) { throw std::invalid_argument( "Floating point types not allowed with left shift."); } - return left_shift(a, b); + return mx::left_shift(a, b); }, "other"_a) .def( "__ilshift__", - [](array& a, const ScalarOrArray v) -> array& { + [](mx::array& a, const ScalarOrArray v) -> mx::array& { if (!is_comparable_with_array(v)) { throw_invalid_operation("inplace left shift", v); } auto b = to_array(v, a.dtype()); - if (issubdtype(a.dtype(), inexact) || - issubdtype(b.dtype(), inexact)) { + if (mx::issubdtype(a.dtype(), mx::inexact) || + mx::issubdtype(b.dtype(), mx::inexact)) { throw std::invalid_argument( "Floating point types not allowed with or left shift."); } - a.overwrite_descriptor(left_shift(a, b)); + a.overwrite_descriptor(mx::left_shift(a, b)); return a; }, "other"_a, nb::rv_policy::none) .def( "__rshift__", - [](const array& a, const ScalarOrArray v) { + [](const mx::array& a, const ScalarOrArray v) { if (!is_comparable_with_array(v)) { throw_invalid_operation("right shift", v); } auto b = to_array(v, a.dtype()); - if (issubdtype(a.dtype(), inexact) || - issubdtype(b.dtype(), inexact)) { + if (mx::issubdtype(a.dtype(), mx::inexact) || + mx::issubdtype(b.dtype(), mx::inexact)) { throw std::invalid_argument( "Floating point types not allowed with right shift."); } - return right_shift(a, b); + return mx::right_shift(a, b); }, "other"_a) .def( "__irshift__", - [](array& a, const ScalarOrArray v) -> array& { + [](mx::array& a, const ScalarOrArray v) -> mx::array& { if (!is_comparable_with_array(v)) { throw_invalid_operation("inplace right shift", v); } auto b = to_array(v, a.dtype()); - if (issubdtype(a.dtype(), inexact) || - issubdtype(b.dtype(), inexact)) { + if (mx::issubdtype(a.dtype(), mx::inexact) || + mx::issubdtype(b.dtype(), mx::inexact)) { throw std::invalid_argument( "Floating point types not allowed with or right shift."); } - a.overwrite_descriptor(right_shift(a, b)); + a.overwrite_descriptor(mx::right_shift(a, b)); return a; }, "other"_a, nb::rv_policy::none) - .def("__int__", [](array& a) { return nb::int_(to_scalar(a)); }) - .def("__float__", [](array& a) { return nb::float_(to_scalar(a)); }) + .def("__int__", [](mx::array& a) { return nb::int_(to_scalar(a)); }) + .def("__float__", [](mx::array& a) { return nb::float_(to_scalar(a)); }) .def( "__format__", - [](array& a, nb::object format_spec) { + [](mx::array& a, nb::object format_spec) { if (nb::len(nb::str(format_spec)) > 0 && a.ndim() > 0) { throw nb::type_error( "unsupported format string passed to mx.array.__format__"); @@ -856,11 +864,11 @@ void init_array(nb::module_& m) { }) .def( "flatten", - [](const array& a, + [](const mx::array& a, int start_axis, int end_axis, - const StreamOrDevice& s) { - return flatten(a, start_axis, end_axis, s); + const mx::StreamOrDevice& s) { + return mx::flatten(a, start_axis, end_axis, s); }, "start_axis"_a = 0, "end_axis"_a = -1, @@ -871,14 +879,14 @@ void init_array(nb::module_& m) { )pbdoc") .def( "reshape", - [](const array& a, nb::args shape_, StreamOrDevice s) { + [](const mx::array& a, nb::args shape_, mx::StreamOrDevice s) { std::vector shape; if (!nb::isinstance(shape_[0])) { shape = nb::cast>(shape_[0]); } else { shape = nb::cast>(shape_); } - return reshape(a, shape, s); + return mx::reshape(a, shape, s); }, "shape"_a, "stream"_a = nb::none(), @@ -890,13 +898,15 @@ void init_array(nb::module_& m) { )pbdoc") .def( "squeeze", - [](const array& a, const IntOrVec& v, const StreamOrDevice& s) { + [](const mx::array& a, + const IntOrVec& v, + const mx::StreamOrDevice& s) { if (std::holds_alternative(v)) { - return squeeze(a, s); + return mx::squeeze(a, s); } else if (auto pv = std::get_if(&v); pv) { - return squeeze(a, *pv, s); + return mx::squeeze(a, *pv, s); } else { - return squeeze(a, std::get>(v), s); + return mx::squeeze(a, std::get>(v), s); } }, "axis"_a = nb::none(), @@ -907,85 +917,87 @@ void init_array(nb::module_& m) { )pbdoc") .def( "abs", - &mlx::core::abs, + &mx::abs, nb::kw_only(), "stream"_a = nb::none(), "See :func:`abs`.") .def( - "__abs__", [](const array& a) { return abs(a); }, "See :func:`abs`.") + "__abs__", + [](const mx::array& a) { return mx::abs(a); }, + "See :func:`abs`.") .def( "square", - &square, + &mx::square, nb::kw_only(), "stream"_a = nb::none(), "See :func:`square`.") .def( "sqrt", - &mlx::core::sqrt, + &mx::sqrt, nb::kw_only(), "stream"_a = nb::none(), "See :func:`sqrt`.") .def( "rsqrt", - &rsqrt, + &mx::rsqrt, nb::kw_only(), "stream"_a = nb::none(), "See :func:`rsqrt`.") .def( "reciprocal", - &reciprocal, + &mx::reciprocal, nb::kw_only(), "stream"_a = nb::none(), "See :func:`reciprocal`.") .def( "exp", - &mlx::core::exp, + &mx::exp, nb::kw_only(), "stream"_a = nb::none(), "See :func:`exp`.") .def( "log", - &mlx::core::log, + &mx::log, nb::kw_only(), "stream"_a = nb::none(), "See :func:`log`.") .def( "log2", - &mlx::core::log2, + &mx::log2, nb::kw_only(), "stream"_a = nb::none(), "See :func:`log2`.") .def( "log10", - &mlx::core::log10, + &mx::log10, nb::kw_only(), "stream"_a = nb::none(), "See :func:`log10`.") .def( "sin", - &mlx::core::sin, + &mx::sin, nb::kw_only(), "stream"_a = nb::none(), "See :func:`sin`.") .def( "cos", - &mlx::core::cos, + &mx::cos, nb::kw_only(), "stream"_a = nb::none(), "See :func:`cos`.") .def( "log1p", - &mlx::core::log1p, + &mx::log1p, nb::kw_only(), "stream"_a = nb::none(), "See :func:`log1p`.") .def( "all", - [](const array& a, + [](const mx::array& a, const IntOrVec& axis, bool keepdims, - StreamOrDevice s) { - return all(a, get_reduce_axes(axis, a.ndim()), keepdims, s); + mx::StreamOrDevice s) { + return mx::all(a, get_reduce_axes(axis, a.ndim()), keepdims, s); }, "axis"_a = nb::none(), "keepdims"_a = false, @@ -994,11 +1006,11 @@ void init_array(nb::module_& m) { "See :func:`all`.") .def( "any", - [](const array& a, + [](const mx::array& a, const IntOrVec& axis, bool keepdims, - StreamOrDevice s) { - return any(a, get_reduce_axes(axis, a.ndim()), keepdims, s); + mx::StreamOrDevice s) { + return mx::any(a, get_reduce_axes(axis, a.ndim()), keepdims, s); }, "axis"_a = nb::none(), "keepdims"_a = false, @@ -1007,7 +1019,7 @@ void init_array(nb::module_& m) { "See :func:`any`.") .def( "moveaxis", - &moveaxis, + &mx::moveaxis, "source"_a, "destination"_a, nb::kw_only(), @@ -1015,7 +1027,7 @@ void init_array(nb::module_& m) { "See :func:`moveaxis`.") .def( "swapaxes", - &swapaxes, + &mx::swapaxes, "axis1"_a, "axis2"_a, nb::kw_only(), @@ -1023,9 +1035,9 @@ void init_array(nb::module_& m) { "See :func:`swapaxes`.") .def( "transpose", - [](const array& a, nb::args axes_, StreamOrDevice s) { + [](const mx::array& a, nb::args axes_, mx::StreamOrDevice s) { if (axes_.size() == 0) { - return transpose(a, s); + return mx::transpose(a, s); } std::vector axes; if (!nb::isinstance(axes_[0])) { @@ -1033,7 +1045,7 @@ void init_array(nb::module_& m) { } else { axes = nb::cast>(axes_); } - return transpose(a, axes, s); + return mx::transpose(a, axes, s); }, "axes"_a, "stream"_a = nb::none(), @@ -1045,15 +1057,15 @@ void init_array(nb::module_& m) { )pbdoc") .def_prop_ro( "T", - [](const array& a) { return transpose(a); }, + [](const mx::array& a) { return mx::transpose(a); }, "Equivalent to calling ``self.transpose()`` with no arguments.") .def( "sum", - [](const array& a, + [](const mx::array& a, const IntOrVec& axis, bool keepdims, - StreamOrDevice s) { - return sum(a, get_reduce_axes(axis, a.ndim()), keepdims, s); + mx::StreamOrDevice s) { + return mx::sum(a, get_reduce_axes(axis, a.ndim()), keepdims, s); }, "axis"_a = nb::none(), "keepdims"_a = false, @@ -1062,11 +1074,11 @@ void init_array(nb::module_& m) { "See :func:`sum`.") .def( "prod", - [](const array& a, + [](const mx::array& a, const IntOrVec& axis, bool keepdims, - StreamOrDevice s) { - return prod(a, get_reduce_axes(axis, a.ndim()), keepdims, s); + mx::StreamOrDevice s) { + return mx::prod(a, get_reduce_axes(axis, a.ndim()), keepdims, s); }, "axis"_a = nb::none(), "keepdims"_a = false, @@ -1075,11 +1087,11 @@ void init_array(nb::module_& m) { "See :func:`prod`.") .def( "min", - [](const array& a, + [](const mx::array& a, const IntOrVec& axis, bool keepdims, - StreamOrDevice s) { - return min(a, get_reduce_axes(axis, a.ndim()), keepdims, s); + mx::StreamOrDevice s) { + return mx::min(a, get_reduce_axes(axis, a.ndim()), keepdims, s); }, "axis"_a = nb::none(), "keepdims"_a = false, @@ -1088,11 +1100,11 @@ void init_array(nb::module_& m) { "See :func:`min`.") .def( "max", - [](const array& a, + [](const mx::array& a, const IntOrVec& axis, bool keepdims, - StreamOrDevice s) { - return max(a, get_reduce_axes(axis, a.ndim()), keepdims, s); + mx::StreamOrDevice s) { + return mx::max(a, get_reduce_axes(axis, a.ndim()), keepdims, s); }, "axis"_a = nb::none(), "keepdims"_a = false, @@ -1101,11 +1113,12 @@ void init_array(nb::module_& m) { "See :func:`max`.") .def( "logsumexp", - [](const array& a, + [](const mx::array& a, const IntOrVec& axis, bool keepdims, - StreamOrDevice s) { - return logsumexp(a, get_reduce_axes(axis, a.ndim()), keepdims, s); + mx::StreamOrDevice s) { + return mx::logsumexp( + a, get_reduce_axes(axis, a.ndim()), keepdims, s); }, "axis"_a = nb::none(), "keepdims"_a = false, @@ -1114,11 +1127,11 @@ void init_array(nb::module_& m) { "See :func:`logsumexp`.") .def( "mean", - [](const array& a, + [](const mx::array& a, const IntOrVec& axis, bool keepdims, - StreamOrDevice s) { - return mean(a, get_reduce_axes(axis, a.ndim()), keepdims, s); + mx::StreamOrDevice s) { + return mx::mean(a, get_reduce_axes(axis, a.ndim()), keepdims, s); }, "axis"_a = nb::none(), "keepdims"_a = false, @@ -1127,12 +1140,12 @@ void init_array(nb::module_& m) { "See :func:`mean`.") .def( "std", - [](const array& a, + [](const mx::array& a, const IntOrVec& axis, bool keepdims, int ddof, - StreamOrDevice s) { - return mlx::core::std( + mx::StreamOrDevice s) { + return mx::std( a, get_reduce_axes(axis, a.ndim()), keepdims, ddof, s); }, "axis"_a = nb::none(), @@ -1143,12 +1156,13 @@ void init_array(nb::module_& m) { "See :func:`std`.") .def( "var", - [](const array& a, + [](const mx::array& a, const IntOrVec& axis, bool keepdims, int ddof, - StreamOrDevice s) { - return var(a, get_reduce_axes(axis, a.ndim()), keepdims, ddof, s); + mx::StreamOrDevice s) { + return mx::var( + a, get_reduce_axes(axis, a.ndim()), keepdims, ddof, s); }, "axis"_a = nb::none(), "keepdims"_a = false, @@ -1158,14 +1172,14 @@ void init_array(nb::module_& m) { "See :func:`var`.") .def( "split", - [](const array& a, + [](const mx::array& a, const std::variant>& indices_or_sections, int axis, - StreamOrDevice s) { + mx::StreamOrDevice s) { if (auto pv = std::get_if(&indices_or_sections); pv) { - return split(a, *pv, axis, s); + return mx::split(a, *pv, axis, s); } else { - return split( + return mx::split( a, std::get>(indices_or_sections), axis, s); } }, @@ -1176,14 +1190,14 @@ void init_array(nb::module_& m) { "See :func:`split`.") .def( "argmin", - [](const array& a, + [](const mx::array& a, std::optional axis, bool keepdims, - StreamOrDevice s) { + mx::StreamOrDevice s) { if (axis) { - return argmin(a, *axis, keepdims, s); + return mx::argmin(a, *axis, keepdims, s); } else { - return argmin(a, keepdims, s); + return mx::argmin(a, keepdims, s); } }, "axis"_a = std::nullopt, @@ -1193,14 +1207,14 @@ void init_array(nb::module_& m) { "See :func:`argmin`.") .def( "argmax", - [](const array& a, + [](const mx::array& a, std::optional axis, bool keepdims, - StreamOrDevice s) { + mx::StreamOrDevice s) { if (axis) { - return argmax(a, *axis, keepdims, s); + return mx::argmax(a, *axis, keepdims, s); } else { - return argmax(a, keepdims, s); + return mx::argmax(a, keepdims, s); } }, "axis"_a = nb::none(), @@ -1210,17 +1224,17 @@ void init_array(nb::module_& m) { "See :func:`argmax`.") .def( "cumsum", - [](const array& a, + [](const mx::array& a, std::optional axis, bool reverse, bool inclusive, - StreamOrDevice s) { + mx::StreamOrDevice s) { if (axis) { - return cumsum(a, *axis, reverse, inclusive, s); + return mx::cumsum(a, *axis, reverse, inclusive, s); } else { // TODO: Implement that in the C++ API as well. See concatenate // above. - return cumsum(reshape(a, {-1}, s), 0, reverse, inclusive, s); + return mx::cumsum(reshape(a, {-1}, s), 0, reverse, inclusive, s); } }, "axis"_a = nb::none(), @@ -1231,17 +1245,18 @@ void init_array(nb::module_& m) { "See :func:`cumsum`.") .def( "cumprod", - [](const array& a, + [](const mx::array& a, std::optional axis, bool reverse, bool inclusive, - StreamOrDevice s) { + mx::StreamOrDevice s) { if (axis) { - return cumprod(a, *axis, reverse, inclusive, s); + return mx::cumprod(a, *axis, reverse, inclusive, s); } else { // TODO: Implement that in the C++ API as well. See concatenate // above. - return cumprod(reshape(a, {-1}, s), 0, reverse, inclusive, s); + return mx::cumprod( + mx::reshape(a, {-1}, s), 0, reverse, inclusive, s); } }, "axis"_a = nb::none(), @@ -1252,17 +1267,18 @@ void init_array(nb::module_& m) { "See :func:`cumprod`.") .def( "cummax", - [](const array& a, + [](const mx::array& a, std::optional axis, bool reverse, bool inclusive, - StreamOrDevice s) { + mx::StreamOrDevice s) { if (axis) { - return cummax(a, *axis, reverse, inclusive, s); + return mx::cummax(a, *axis, reverse, inclusive, s); } else { // TODO: Implement that in the C++ API as well. See concatenate // above. - return cummax(reshape(a, {-1}, s), 0, reverse, inclusive, s); + return mx::cummax( + mx::reshape(a, {-1}, s), 0, reverse, inclusive, s); } }, "axis"_a = nb::none(), @@ -1273,17 +1289,18 @@ void init_array(nb::module_& m) { "See :func:`cummax`.") .def( "cummin", - [](const array& a, + [](const mx::array& a, std::optional axis, bool reverse, bool inclusive, - StreamOrDevice s) { + mx::StreamOrDevice s) { if (axis) { - return cummin(a, *axis, reverse, inclusive, s); + return mx::cummin(a, *axis, reverse, inclusive, s); } else { // TODO: Implement that in the C++ API as well. See concatenate // above. - return cummin(reshape(a, {-1}, s), 0, reverse, inclusive, s); + return mx::cummin( + mx::reshape(a, {-1}, s), 0, reverse, inclusive, s); } }, "axis"_a = nb::none(), @@ -1294,8 +1311,8 @@ void init_array(nb::module_& m) { "See :func:`cummin`.") .def( "round", - [](const array& a, int decimals, StreamOrDevice s) { - return round(a, decimals, s); + [](const mx::array& a, int decimals, mx::StreamOrDevice s) { + return mx::round(a, decimals, s); }, "decimals"_a = 0, nb::kw_only(), @@ -1303,11 +1320,13 @@ void init_array(nb::module_& m) { "See :func:`round`.") .def( "diagonal", - [](const array& a, + [](const mx::array& a, int offset, int axis1, int axis2, - StreamOrDevice s) { return diagonal(a, offset, axis1, axis2, s); }, + mx::StreamOrDevice s) { + return mx::diagonal(a, offset, axis1, axis2, s); + }, "offset"_a = 0, "axis1"_a = 0, "axis2"_a = 1, @@ -1315,7 +1334,9 @@ void init_array(nb::module_& m) { "See :func:`diagonal`.") .def( "diag", - [](const array& a, int k, StreamOrDevice s) { return diag(a, k, s); }, + [](const mx::array& a, int k, mx::StreamOrDevice s) { + return mx::diag(a, k, s); + }, "k"_a = 0, nb::kw_only(), "stream"_a = nb::none(), @@ -1324,17 +1345,17 @@ void init_array(nb::module_& m) { )pbdoc") .def( "conj", - [](const array& a, StreamOrDevice s) { - return mlx::core::conjugate(to_array(a), s); + [](const mx::array& a, mx::StreamOrDevice s) { + return mx::conjugate(to_array(a), s); }, nb::kw_only(), "stream"_a = nb::none(), "See :func:`conj`.") .def( "view", - [](const ScalarOrArray& a, const Dtype& dtype, StreamOrDevice s) { - return view(to_array(a), dtype, s); - }, + [](const ScalarOrArray& a, + const mx::Dtype& dtype, + mx::StreamOrDevice s) { return mx::view(to_array(a), dtype, s); }, "dtype"_a, nb::kw_only(), "stream"_a = nb::none(), diff --git a/python/src/buffer.h b/python/src/buffer.h index 33cda42ca..cca832686 100644 --- a/python/src/buffer.h +++ b/python/src/buffer.h @@ -14,37 +14,37 @@ #define Py_bf_releasebuffer 2 #endif +namespace mx = mlx::core; namespace nb = nanobind; -using namespace mlx::core; -std::string buffer_format(const array& a) { +std::string buffer_format(const mx::array& a) { // https://docs.python.org/3.10/library/struct.html#format-characters switch (a.dtype()) { - case bool_: + case mx::bool_: return "?"; - case uint8: + case mx::uint8: return "B"; - case uint16: + case mx::uint16: return "H"; - case uint32: + case mx::uint32: return "I"; - case uint64: + case mx::uint64: return "Q"; - case int8: + case mx::int8: return "b"; - case int16: + case mx::int16: return "h"; - case int32: + case mx::int32: return "i"; - case int64: + case mx::int64: return "q"; - case float16: + case mx::float16: return "e"; - case float32: + case mx::float32: return "f"; - case bfloat16: + case mx::bfloat16: return "B"; - case complex64: + case mx::complex64: return "Zf\0"; default: { std::ostringstream os; @@ -84,7 +84,7 @@ struct buffer_info { extern "C" inline int getbuffer(PyObject* obj, Py_buffer* view, int flags) { std::memset(view, 0, sizeof(Py_buffer)); - auto a = nb::cast(nb::handle(obj)); + auto a = nb::cast(nb::handle(obj)); { nb::gil_scoped_release nogil; diff --git a/python/src/convert.cpp b/python/src/convert.cpp index 46547ec0c..04c4f05b6 100644 --- a/python/src/convert.cpp +++ b/python/src/convert.cpp @@ -16,7 +16,7 @@ enum PyScalarT { namespace nanobind { template <> -struct ndarray_traits { +struct ndarray_traits { static constexpr bool is_complex = false; static constexpr bool is_float = true; static constexpr bool is_bool = false; @@ -36,21 +36,21 @@ int check_shape_dim(int64_t dim) { } template -array nd_array_to_mlx_contiguous( +mx::array nd_array_to_mlx_contiguous( nb::ndarray nd_array, - const Shape& shape, - Dtype dtype) { + const mx::Shape& shape, + mx::Dtype dtype) { // Make a copy of the numpy buffer // Get buffer ptr pass to array constructor auto data_ptr = nd_array.data(); - return array(static_cast(data_ptr), shape, dtype); + return mx::array(static_cast(data_ptr), shape, dtype); } -array nd_array_to_mlx( +mx::array nd_array_to_mlx( nb::ndarray nd_array, - std::optional dtype) { + std::optional dtype) { // Compute the shape and size - Shape shape; + mx::Shape shape; for (int i = 0; i < nd_array.ndim(); i++) { shape.push_back(check_shape_dim(nd_array.shape(i))); } @@ -59,49 +59,49 @@ array nd_array_to_mlx( // Copy data and make array if (type == nb::dtype()) { return nd_array_to_mlx_contiguous( - nd_array, shape, dtype.value_or(bool_)); + nd_array, shape, dtype.value_or(mx::bool_)); } else if (type == nb::dtype()) { return nd_array_to_mlx_contiguous( - nd_array, shape, dtype.value_or(uint8)); + nd_array, shape, dtype.value_or(mx::uint8)); } else if (type == nb::dtype()) { return nd_array_to_mlx_contiguous( - nd_array, shape, dtype.value_or(uint16)); + nd_array, shape, dtype.value_or(mx::uint16)); } else if (type == nb::dtype()) { return nd_array_to_mlx_contiguous( - nd_array, shape, dtype.value_or(uint32)); + nd_array, shape, dtype.value_or(mx::uint32)); } else if (type == nb::dtype()) { return nd_array_to_mlx_contiguous( - nd_array, shape, dtype.value_or(uint64)); + nd_array, shape, dtype.value_or(mx::uint64)); } else if (type == nb::dtype()) { return nd_array_to_mlx_contiguous( - nd_array, shape, dtype.value_or(int8)); + nd_array, shape, dtype.value_or(mx::int8)); } else if (type == nb::dtype()) { return nd_array_to_mlx_contiguous( - nd_array, shape, dtype.value_or(int16)); + nd_array, shape, dtype.value_or(mx::int16)); } else if (type == nb::dtype()) { return nd_array_to_mlx_contiguous( - nd_array, shape, dtype.value_or(int32)); + nd_array, shape, dtype.value_or(mx::int32)); } else if (type == nb::dtype()) { return nd_array_to_mlx_contiguous( - nd_array, shape, dtype.value_or(int64)); - } else if (type == nb::dtype()) { - return nd_array_to_mlx_contiguous( - nd_array, shape, dtype.value_or(float16)); + nd_array, shape, dtype.value_or(mx::int64)); + } else if (type == nb::dtype()) { + return nd_array_to_mlx_contiguous( + nd_array, shape, dtype.value_or(mx::float16)); } else if (type == nb::bfloat16) { - return nd_array_to_mlx_contiguous( - nd_array, shape, dtype.value_or(bfloat16)); + return nd_array_to_mlx_contiguous( + nd_array, shape, dtype.value_or(mx::bfloat16)); } else if (type == nb::dtype()) { return nd_array_to_mlx_contiguous( - nd_array, shape, dtype.value_or(float32)); + nd_array, shape, dtype.value_or(mx::float32)); } else if (type == nb::dtype()) { return nd_array_to_mlx_contiguous( - nd_array, shape, dtype.value_or(float32)); + nd_array, shape, dtype.value_or(mx::float32)); } else if (type == nb::dtype>()) { - return nd_array_to_mlx_contiguous( - nd_array, shape, dtype.value_or(complex64)); + return nd_array_to_mlx_contiguous( + nd_array, shape, dtype.value_or(mx::complex64)); } else if (type == nb::dtype>()) { - return nd_array_to_mlx_contiguous( - nd_array, shape, dtype.value_or(complex64)); + return nd_array_to_mlx_contiguous( + nd_array, shape, dtype.value_or(mx::complex64)); } else { throw std::invalid_argument("Cannot convert numpy array to mlx array."); } @@ -109,7 +109,7 @@ array nd_array_to_mlx( template nb::ndarray mlx_to_nd_array_impl( - array a, + mx::array a, std::optional t = {}) { { nb::gil_scoped_release nogil; @@ -126,48 +126,48 @@ nb::ndarray mlx_to_nd_array_impl( } template -nb::ndarray mlx_to_nd_array(const array& a) { +nb::ndarray mlx_to_nd_array(const mx::array& a) { switch (a.dtype()) { - case bool_: + case mx::bool_: return mlx_to_nd_array_impl(a); - case uint8: + case mx::uint8: return mlx_to_nd_array_impl(a); - case uint16: + case mx::uint16: return mlx_to_nd_array_impl(a); - case uint32: + case mx::uint32: return mlx_to_nd_array_impl(a); - case uint64: + case mx::uint64: return mlx_to_nd_array_impl(a); - case int8: + case mx::int8: return mlx_to_nd_array_impl(a); - case int16: + case mx::int16: return mlx_to_nd_array_impl(a); - case int32: + case mx::int32: return mlx_to_nd_array_impl(a); - case int64: + case mx::int64: return mlx_to_nd_array_impl(a); - case float16: - return mlx_to_nd_array_impl(a); - case bfloat16: + case mx::float16: + return mlx_to_nd_array_impl(a); + case mx::bfloat16: throw nb::type_error("bfloat16 arrays cannot be converted to NumPy."); - case float32: + case mx::float32: return mlx_to_nd_array_impl(a); - case complex64: + case mx::complex64: return mlx_to_nd_array_impl, NDParams...>(a); default: throw nb::type_error("type cannot be converted to NumPy."); } } -nb::ndarray mlx_to_np_array(const array& a) { +nb::ndarray mlx_to_np_array(const mx::array& a) { return mlx_to_nd_array(a); } -nb::ndarray<> mlx_to_dlpack(const array& a) { +nb::ndarray<> mlx_to_dlpack(const mx::array& a) { return mlx_to_nd_array<>(a); } -nb::object to_scalar(array& a) { +nb::object to_scalar(mx::array& a) { if (a.size() != 1) { throw std::invalid_argument( "[convert] Only length-1 arrays can be converted to Python scalars."); @@ -177,31 +177,31 @@ nb::object to_scalar(array& a) { a.eval(); } switch (a.dtype()) { - case bool_: + case mx::bool_: return nb::cast(a.item()); - case uint8: + case mx::uint8: return nb::cast(a.item()); - case uint16: + case mx::uint16: return nb::cast(a.item()); - case uint32: + case mx::uint32: return nb::cast(a.item()); - case uint64: + case mx::uint64: return nb::cast(a.item()); - case int8: + case mx::int8: return nb::cast(a.item()); - case int16: + case mx::int16: return nb::cast(a.item()); - case int32: + case mx::int32: return nb::cast(a.item()); - case int64: + case mx::int64: return nb::cast(a.item()); - case float16: - return nb::cast(static_cast(a.item())); - case float32: + case mx::float16: + return nb::cast(static_cast(a.item())); + case mx::float32: return nb::cast(a.item()); - case bfloat16: - return nb::cast(static_cast(a.item())); - case complex64: + case mx::bfloat16: + return nb::cast(static_cast(a.item())); + case mx::complex64: return nb::cast(a.item>()); default: throw nb::type_error("type cannot be converted to Python scalar."); @@ -209,7 +209,7 @@ nb::object to_scalar(array& a) { } template -nb::list to_list(array& a, size_t index, int dim) { +nb::list to_list(mx::array& a, size_t index, int dim) { nb::list pl; auto stride = a.strides()[dim]; for (int i = 0; i < a.shape(dim); ++i) { @@ -223,7 +223,7 @@ nb::list to_list(array& a, size_t index, int dim) { return pl; } -nb::object tolist(array& a) { +nb::object tolist(mx::array& a) { if (a.ndim() == 0) { return to_scalar(a); } @@ -232,31 +232,31 @@ nb::object tolist(array& a) { a.eval(); } switch (a.dtype()) { - case bool_: + case mx::bool_: return to_list(a, 0, 0); - case uint8: + case mx::uint8: return to_list(a, 0, 0); - case uint16: + case mx::uint16: return to_list(a, 0, 0); - case uint32: + case mx::uint32: return to_list(a, 0, 0); - case uint64: + case mx::uint64: return to_list(a, 0, 0); - case int8: + case mx::int8: return to_list(a, 0, 0); - case int16: + case mx::int16: return to_list(a, 0, 0); - case int32: + case mx::int32: return to_list(a, 0, 0); - case int64: + case mx::int64: return to_list(a, 0, 0); - case float16: - return to_list(a, 0, 0); - case float32: + case mx::float16: + return to_list(a, 0, 0); + case mx::float32: return to_list(a, 0, 0); - case bfloat16: - return to_list(a, 0, 0); - case complex64: + case mx::bfloat16: + return to_list(a, 0, 0); + case mx::complex64: return to_list>(a, 0, 0); default: throw nb::type_error("data type cannot be converted to Python list."); @@ -279,7 +279,7 @@ void fill_vector(T list, std::vector& vals) { template PyScalarT validate_shape( T list, - const Shape& shape, + const mx::Shape& shape, int idx, bool& all_python_primitive_elements) { if (idx >= shape.size()) { @@ -307,9 +307,9 @@ PyScalarT validate_shape( shape, idx + 1, all_python_primitive_elements); - } else if (nb::isinstance(l)) { + } else if (nb::isinstance(l)) { all_python_primitive_elements = false; - auto arr = nb::cast(l); + auto arr = nb::cast(l); if (arr.ndim() + idx + 1 == shape.size() && std::equal( arr.shape().cbegin(), @@ -347,7 +347,7 @@ PyScalarT validate_shape( } template -void get_shape(T list, Shape& shape) { +void get_shape(T list, mx::Shape& shape) { shape.push_back(check_shape_dim(nb::len(list))); if (shape.back() > 0) { auto l = list.begin(); @@ -355,8 +355,8 @@ void get_shape(T list, Shape& shape) { return get_shape(nb::cast(*l), shape); } else if (nb::isinstance(*l)) { return get_shape(nb::cast(*l), shape); - } else if (nb::isinstance(*l)) { - auto arr = nb::cast(*l); + } else if (nb::isinstance(*l)) { + auto arr = nb::cast(*l); for (int i = 0; i < arr.ndim(); i++) { shape.push_back(arr.shape(i)); } @@ -366,54 +366,55 @@ void get_shape(T list, Shape& shape) { } template -array array_from_list_impl( +mx::array array_from_list_impl( T pl, const PyScalarT& inferred_type, - std::optional specified_type, - const Shape& shape) { + std::optional specified_type, + const mx::Shape& shape) { // Make the array switch (inferred_type) { case pybool: { std::vector vals; fill_vector(pl, vals); - return array(vals.begin(), shape, specified_type.value_or(bool_)); + return mx::array(vals.begin(), shape, specified_type.value_or(mx::bool_)); } case pyint: { - auto dtype = specified_type.value_or(int32); - if (dtype == int64) { + auto dtype = specified_type.value_or(mx::int32); + if (dtype == mx::int64) { std::vector vals; fill_vector(pl, vals); - return array(vals.begin(), shape, dtype); - } else if (dtype == uint64) { + return mx::array(vals.begin(), shape, dtype); + } else if (dtype == mx::uint64) { std::vector vals; fill_vector(pl, vals); - return array(vals.begin(), shape, dtype); - } else if (dtype == uint32) { + return mx::array(vals.begin(), shape, dtype); + } else if (dtype == mx::uint32) { std::vector vals; fill_vector(pl, vals); - return array(vals.begin(), shape, dtype); - } else if (issubdtype(dtype, inexact)) { + return mx::array(vals.begin(), shape, dtype); + } else if (mx::issubdtype(dtype, mx::inexact)) { std::vector vals; fill_vector(pl, vals); - return array(vals.begin(), shape, dtype); + return mx::array(vals.begin(), shape, dtype); } else { std::vector vals; fill_vector(pl, vals); - return array(vals.begin(), shape, dtype); + return mx::array(vals.begin(), shape, dtype); } } case pyfloat: { std::vector vals; fill_vector(pl, vals); - return array(vals.begin(), shape, specified_type.value_or(float32)); + return mx::array( + vals.begin(), shape, specified_type.value_or(mx::float32)); } case pycomplex: { std::vector> vals; fill_vector(pl, vals); - return array( - reinterpret_cast(vals.data()), + return mx::array( + reinterpret_cast(vals.data()), shape, - specified_type.value_or(complex64)); + specified_type.value_or(mx::complex64)); } default: { std::ostringstream msg; @@ -425,9 +426,9 @@ array array_from_list_impl( } template -array array_from_list_impl(T pl, std::optional dtype) { +mx::array array_from_list_impl(T pl, std::optional dtype) { // Compute the shape - Shape shape; + mx::Shape shape; get_shape(pl, shape); // Validate the shape and type @@ -440,30 +441,31 @@ array array_from_list_impl(T pl, std::optional dtype) { } // `pl` contains mlx arrays - std::vector arrays; + std::vector arrays; for (auto l : pl) { arrays.push_back(create_array(nb::cast(l), dtype)); } - return stack(arrays); + return mx::stack(arrays); } -array array_from_list(nb::list pl, std::optional dtype) { +mx::array array_from_list(nb::list pl, std::optional dtype) { return array_from_list_impl(pl, dtype); } -array array_from_list(nb::tuple pl, std::optional dtype) { +mx::array array_from_list(nb::tuple pl, std::optional dtype) { return array_from_list_impl(pl, dtype); } -array create_array(ArrayInitType v, std::optional t) { +mx::array create_array(ArrayInitType v, std::optional t) { if (auto pv = std::get_if(&v); pv) { - return array(nb::cast(*pv), t.value_or(bool_)); + return mx::array(nb::cast(*pv), t.value_or(mx::bool_)); } else if (auto pv = std::get_if(&v); pv) { - return array(nb::cast(*pv), t.value_or(int32)); + return mx::array(nb::cast(*pv), t.value_or(mx::int32)); } else if (auto pv = std::get_if(&v); pv) { - return array(nb::cast(*pv), t.value_or(float32)); + return mx::array(nb::cast(*pv), t.value_or(mx::float32)); } else if (auto pv = std::get_if>(&v); pv) { - return array(static_cast(*pv), t.value_or(complex64)); + return mx::array( + static_cast(*pv), t.value_or(mx::complex64)); } else if (auto pv = std::get_if(&v); pv) { return array_from_list(*pv, t); } else if (auto pv = std::get_if(&v); pv) { @@ -472,10 +474,10 @@ array create_array(ArrayInitType v, std::optional t) { nb::ndarray>(&v); pv) { return nd_array_to_mlx(*pv, t); - } else if (auto pv = std::get_if(&v); pv) { - return astype(*pv, t.value_or((*pv).dtype())); + } else if (auto pv = std::get_if(&v); pv) { + return mx::astype(*pv, t.value_or((*pv).dtype())); } else { auto arr = to_array_with_accessor(std::get(v)); - return astype(arr, t.value_or(arr.dtype())); + return mx::astype(arr, t.value_or(arr.dtype())); } } diff --git a/python/src/convert.h b/python/src/convert.h index 3c899ca34..44a090c2b 100644 --- a/python/src/convert.h +++ b/python/src/convert.h @@ -9,15 +9,15 @@ #include "mlx/array.h" #include "mlx/ops.h" +namespace mx = mlx::core; namespace nb = nanobind; -using namespace mlx::core; using ArrayInitType = std::variant< nb::bool_, nb::int_, nb::float_, // Must be above ndarray - array, + mx::array, // Must be above complex nb::ndarray, std::complex, @@ -25,17 +25,17 @@ using ArrayInitType = std::variant< nb::tuple, nb::object>; -array nd_array_to_mlx( +mx::array nd_array_to_mlx( nb::ndarray nd_array, - std::optional dtype); + std::optional dtype); -nb::ndarray mlx_to_np_array(const array& a); -nb::ndarray<> mlx_to_dlpack(const array& a); +nb::ndarray mlx_to_np_array(const mx::array& a); +nb::ndarray<> mlx_to_dlpack(const mx::array& a); -nb::object to_scalar(array& a); +nb::object to_scalar(mx::array& a); -nb::object tolist(array& a); +nb::object tolist(mx::array& a); -array create_array(ArrayInitType v, std::optional t); -array array_from_list(nb::list pl, std::optional dtype); -array array_from_list(nb::tuple pl, std::optional dtype); +mx::array create_array(ArrayInitType v, std::optional t); +mx::array array_from_list(nb::list pl, std::optional dtype); +mx::array array_from_list(nb::tuple pl, std::optional dtype); diff --git a/python/src/device.cpp b/python/src/device.cpp index 1d0c38b74..85b15dd4d 100644 --- a/python/src/device.cpp +++ b/python/src/device.cpp @@ -8,51 +8,54 @@ #include "mlx/device.h" #include "mlx/utils.h" +namespace mx = mlx::core; namespace nb = nanobind; using namespace nb::literals; -using namespace mlx::core; void init_device(nb::module_& m) { - auto device_class = nb::class_( + auto device_class = nb::class_( m, "Device", R"pbdoc(A device to run operations on.)pbdoc"); - nb::enum_(m, "DeviceType") - .value("cpu", Device::DeviceType::cpu) - .value("gpu", Device::DeviceType::gpu) + nb::enum_(m, "DeviceType") + .value("cpu", mx::Device::DeviceType::cpu) + .value("gpu", mx::Device::DeviceType::gpu) .export_values() - .def("__eq__", [](const Device::DeviceType& d, const nb::object& other) { - if (!nb::isinstance(other) && - !nb::isinstance(other)) { - return false; - } - return d == nb::cast(other); - }); - - device_class.def(nb::init(), "type"_a, "index"_a = 0) - .def_ro("type", &Device::type) + .def( + "__eq__", + [](const mx::Device::DeviceType& d, const nb::object& other) { + if (!nb::isinstance(other) && + !nb::isinstance(other)) { + return false; + } + return d == nb::cast(other); + }); + + device_class + .def(nb::init(), "type"_a, "index"_a = 0) + .def_ro("type", &mx::Device::type) .def( "__repr__", - [](const Device& d) { + [](const mx::Device& d) { std::ostringstream os; os << d; return os.str(); }) - .def("__eq__", [](const Device& d, const nb::object& other) { - if (!nb::isinstance(other) && - !nb::isinstance(other)) { + .def("__eq__", [](const mx::Device& d, const nb::object& other) { + if (!nb::isinstance(other) && + !nb::isinstance(other)) { return false; } - return d == nb::cast(other); + return d == nb::cast(other); }); - nb::implicitly_convertible(); + nb::implicitly_convertible(); m.def( "default_device", - &default_device, + &mx::default_device, R"pbdoc(Get the default device.)pbdoc"); m.def( "set_default_device", - &set_default_device, + &mx::set_default_device, "device"_a, R"pbdoc(Set the default device.)pbdoc"); } diff --git a/python/src/distributed.cpp b/python/src/distributed.cpp index 697b8bd58..ebce7acb5 100644 --- a/python/src/distributed.cpp +++ b/python/src/distributed.cpp @@ -9,26 +9,27 @@ #include "mlx/distributed/distributed.h" #include "mlx/distributed/ops.h" +namespace mx = mlx::core; namespace nb = nanobind; using namespace nb::literals; -using namespace mlx::core; void init_distributed(nb::module_& parent_module) { auto m = parent_module.def_submodule( "distributed", "mlx.core.distributed: Communication operations"); - nb::class_( + nb::class_( m, "Group", R"pbcopy( An :class:`mlx.core.distributed.Group` represents a group of independent mlx processes that can communicate. )pbcopy") - .def("rank", &distributed::Group::rank, "Get the rank of this process") - .def("size", &distributed::Group::size, "Get the size of the group") + .def( + "rank", &mx::distributed::Group::rank, "Get the rank of this process") + .def("size", &mx::distributed::Group::size, "Get the size of the group") .def( "split", - &distributed::Group::split, + &mx::distributed::Group::split, "color"_a, "key"_a = -1, nb::sig("def split(self, color: int, key: int = -1) -> Group"), @@ -48,14 +49,14 @@ void init_distributed(nb::module_& parent_module) { m.def( "is_available", - &distributed::is_available, + &mx::distributed::is_available, R"pbdoc( Check if a communication backend is available. )pbdoc"); m.def( "init", - &distributed::init, + &mx::distributed::init, "strict"_a = false, nb::sig("def init(strict: bool = False) -> Group"), R"pbdoc( @@ -72,7 +73,7 @@ void init_distributed(nb::module_& parent_module) { m.def( "all_sum", - &distributed::all_sum, + &mx::distributed::all_sum, "x"_a, nb::kw_only(), "group"_a = nb::none(), @@ -98,7 +99,7 @@ void init_distributed(nb::module_& parent_module) { m.def( "all_gather", - &distributed::all_gather, + &mx::distributed::all_gather, "x"_a, nb::kw_only(), "group"_a = nb::none(), @@ -125,7 +126,7 @@ void init_distributed(nb::module_& parent_module) { m.def( "send", - &distributed::send, + &mx::distributed::send, "x"_a, "dst"_a, nb::kw_only(), @@ -152,7 +153,7 @@ void init_distributed(nb::module_& parent_module) { m.def( "recv", - &distributed::recv, + &mx::distributed::recv, "shape"_a, "dtype"_a, "src"_a, @@ -181,7 +182,7 @@ void init_distributed(nb::module_& parent_module) { m.def( "recv_like", - &distributed::recv_like, + &mx::distributed::recv_like, "x"_a, "src"_a, nb::kw_only(), diff --git a/python/src/fast.cpp b/python/src/fast.cpp index cbc8b934d..103c5e76d 100644 --- a/python/src/fast.cpp +++ b/python/src/fast.cpp @@ -13,9 +13,9 @@ #include "mlx/fast.h" #include "mlx/ops.h" +namespace mx = mlx::core; namespace nb = nanobind; using namespace nb::literals; -using namespace mlx::core; void init_fast(nb::module_& parent_module) { auto m = @@ -23,7 +23,7 @@ void init_fast(nb::module_& parent_module) { m.def( "rms_norm", - &fast::rms_norm, + &mx::fast::rms_norm, "x"_a, "weight"_a, "eps"_a, @@ -49,7 +49,7 @@ void init_fast(nb::module_& parent_module) { m.def( "layer_norm", - &fast::layer_norm, + &mx::fast::layer_norm, "x"_a, "weight"_a.none(), "bias"_a.none(), @@ -79,7 +79,7 @@ void init_fast(nb::module_& parent_module) { m.def( "rope", - &fast::rope, + &mx::fast::rope, "a"_a, "dims"_a, nb::kw_only(), @@ -114,7 +114,7 @@ void init_fast(nb::module_& parent_module) { m.def( "scaled_dot_product_attention", - &fast::scaled_dot_product_attention, + &mx::fast::scaled_dot_product_attention, "q"_a, "k"_a, "v"_a, @@ -170,7 +170,7 @@ void init_fast(nb::module_& parent_module) { const std::string& header, bool ensure_row_contiguous, bool atomic_outputs) { - auto kernel = fast::metal_kernel( + auto kernel = mx::fast::metal_kernel( name, input_names, output_names, @@ -182,7 +182,7 @@ void init_fast(nb::module_& parent_module) { [kernel = std::move(kernel)]( const std::vector& inputs_, const std::vector>& output_shapes, - const std::vector& output_dtypes, + const std::vector& output_dtypes, std::tuple grid, std::tuple threadgroup, const std::optional< @@ -190,12 +190,12 @@ void init_fast(nb::module_& parent_module) { template_args_ = std::nullopt, std::optional init_value = std::nullopt, bool verbose = false, - StreamOrDevice s = {}) { - std::vector inputs; + mx::StreamOrDevice s = {}) { + std::vector inputs; for (const auto& value : inputs_) { inputs.push_back(to_array(value, std::nullopt)); } - std::vector> + std::vector> template_args; if (template_args_) { for (const auto& [name, value] : template_args_.value()) { @@ -206,8 +206,8 @@ void init_fast(nb::module_& parent_module) { } else if (nb::isinstance(value)) { int int_val = nb::cast(value); template_args.emplace_back(name, int_val); - } else if (nb::isinstance(value)) { - Dtype dtype = nb::cast(value); + } else if (nb::isinstance(value)) { + mx::Dtype dtype = nb::cast(value); template_args.emplace_back(name, dtype); } else { throw std::invalid_argument( diff --git a/python/src/fft.cpp b/python/src/fft.cpp index 44e914d05..986cd8f67 100644 --- a/python/src/fft.cpp +++ b/python/src/fft.cpp @@ -9,24 +9,23 @@ #include "mlx/fft.h" #include "mlx/ops.h" +namespace mx = mlx::core; namespace nb = nanobind; using namespace nb::literals; -using namespace mlx::core; - void init_fft(nb::module_& parent_module) { auto m = parent_module.def_submodule( "fft", "mlx.core.fft: Fast Fourier Transforms."); m.def( "fft", - [](const array& a, + [](const mx::array& a, const std::optional& n, int axis, - StreamOrDevice s) { + mx::StreamOrDevice s) { if (n.has_value()) { - return fft::fft(a, n.value(), axis, s); + return mx::fft::fft(a, n.value(), axis, s); } else { - return fft::fft(a, axis, s); + return mx::fft::fft(a, axis, s); } }, "a"_a, @@ -49,14 +48,14 @@ void init_fft(nb::module_& parent_module) { )pbdoc"); m.def( "ifft", - [](const array& a, + [](const mx::array& a, const std::optional& n, int axis, - StreamOrDevice s) { + mx::StreamOrDevice s) { if (n.has_value()) { - return fft::ifft(a, n.value(), axis, s); + return mx::fft::ifft(a, n.value(), axis, s); } else { - return fft::ifft(a, axis, s); + return mx::fft::ifft(a, axis, s); } }, "a"_a, @@ -79,19 +78,19 @@ void init_fft(nb::module_& parent_module) { )pbdoc"); m.def( "fft2", - [](const array& a, + [](const mx::array& a, const std::optional>& n, const std::optional>& axes, - StreamOrDevice s) { + mx::StreamOrDevice s) { if (axes.has_value() && n.has_value()) { - return fft::fftn(a, n.value(), axes.value(), s); + return mx::fft::fftn(a, n.value(), axes.value(), s); } else if (axes.has_value()) { - return fft::fftn(a, axes.value(), s); + return mx::fft::fftn(a, axes.value(), s); } else if (n.has_value()) { throw std::invalid_argument( "[fft2] `axes` should not be `None` if `s` is not `None`."); } else { - return fft::fftn(a, s); + return mx::fft::fftn(a, s); } }, "a"_a, @@ -115,19 +114,19 @@ void init_fft(nb::module_& parent_module) { )pbdoc"); m.def( "ifft2", - [](const array& a, + [](const mx::array& a, const std::optional>& n, const std::optional>& axes, - StreamOrDevice s) { + mx::StreamOrDevice s) { if (axes.has_value() && n.has_value()) { - return fft::ifftn(a, n.value(), axes.value(), s); + return mx::fft::ifftn(a, n.value(), axes.value(), s); } else if (axes.has_value()) { - return fft::ifftn(a, axes.value(), s); + return mx::fft::ifftn(a, axes.value(), s); } else if (n.has_value()) { throw std::invalid_argument( "[ifft2] `axes` should not be `None` if `s` is not `None`."); } else { - return fft::ifftn(a, s); + return mx::fft::ifftn(a, s); } }, "a"_a, @@ -151,19 +150,19 @@ void init_fft(nb::module_& parent_module) { )pbdoc"); m.def( "fftn", - [](const array& a, + [](const mx::array& a, const std::optional>& n, const std::optional>& axes, - StreamOrDevice s) { + mx::StreamOrDevice s) { if (axes.has_value() && n.has_value()) { - return fft::fftn(a, n.value(), axes.value(), s); + return mx::fft::fftn(a, n.value(), axes.value(), s); } else if (axes.has_value()) { - return fft::fftn(a, axes.value(), s); + return mx::fft::fftn(a, axes.value(), s); } else if (n.has_value()) { throw std::invalid_argument( "[fftn] `axes` should not be `None` if `s` is not `None`."); } else { - return fft::fftn(a, s); + return mx::fft::fftn(a, s); } }, "a"_a, @@ -188,19 +187,19 @@ void init_fft(nb::module_& parent_module) { )pbdoc"); m.def( "ifftn", - [](const array& a, + [](const mx::array& a, const std::optional>& n, const std::optional>& axes, - StreamOrDevice s) { + mx::StreamOrDevice s) { if (axes.has_value() && n.has_value()) { - return fft::ifftn(a, n.value(), axes.value(), s); + return mx::fft::ifftn(a, n.value(), axes.value(), s); } else if (axes.has_value()) { - return fft::ifftn(a, axes.value(), s); + return mx::fft::ifftn(a, axes.value(), s); } else if (n.has_value()) { throw std::invalid_argument( "[ifftn] `axes` should not be `None` if `s` is not `None`."); } else { - return fft::ifftn(a, s); + return mx::fft::ifftn(a, s); } }, "a"_a, @@ -225,14 +224,14 @@ void init_fft(nb::module_& parent_module) { )pbdoc"); m.def( "rfft", - [](const array& a, + [](const mx::array& a, const std::optional& n, int axis, - StreamOrDevice s) { + mx::StreamOrDevice s) { if (n.has_value()) { - return fft::rfft(a, n.value(), axis, s); + return mx::fft::rfft(a, n.value(), axis, s); } else { - return fft::rfft(a, axis, s); + return mx::fft::rfft(a, axis, s); } }, "a"_a, @@ -260,14 +259,14 @@ void init_fft(nb::module_& parent_module) { )pbdoc"); m.def( "irfft", - [](const array& a, + [](const mx::array& a, const std::optional& n, int axis, - StreamOrDevice s) { + mx::StreamOrDevice s) { if (n.has_value()) { - return fft::irfft(a, n.value(), axis, s); + return mx::fft::irfft(a, n.value(), axis, s); } else { - return fft::irfft(a, axis, s); + return mx::fft::irfft(a, axis, s); } }, "a"_a, @@ -294,19 +293,19 @@ void init_fft(nb::module_& parent_module) { )pbdoc"); m.def( "rfft2", - [](const array& a, + [](const mx::array& a, const std::optional>& n, const std::optional>& axes, - StreamOrDevice s) { + mx::StreamOrDevice s) { if (axes.has_value() && n.has_value()) { - return fft::rfftn(a, n.value(), axes.value(), s); + return mx::fft::rfftn(a, n.value(), axes.value(), s); } else if (axes.has_value()) { - return fft::rfftn(a, axes.value(), s); + return mx::fft::rfftn(a, axes.value(), s); } else if (n.has_value()) { throw std::invalid_argument( "[rfft2] `axes` should not be `None` if `s` is not `None`."); } else { - return fft::rfftn(a, s); + return mx::fft::rfftn(a, s); } }, "a"_a, @@ -336,19 +335,19 @@ void init_fft(nb::module_& parent_module) { )pbdoc"); m.def( "irfft2", - [](const array& a, + [](const mx::array& a, const std::optional>& n, const std::optional>& axes, - StreamOrDevice s) { + mx::StreamOrDevice s) { if (axes.has_value() && n.has_value()) { - return fft::irfftn(a, n.value(), axes.value(), s); + return mx::fft::irfftn(a, n.value(), axes.value(), s); } else if (axes.has_value()) { - return fft::irfftn(a, axes.value(), s); + return mx::fft::irfftn(a, axes.value(), s); } else if (n.has_value()) { throw std::invalid_argument( "[irfft2] `axes` should not be `None` if `s` is not `None`."); } else { - return fft::irfftn(a, s); + return mx::fft::irfftn(a, s); } }, "a"_a, @@ -378,19 +377,19 @@ void init_fft(nb::module_& parent_module) { )pbdoc"); m.def( "rfftn", - [](const array& a, + [](const mx::array& a, const std::optional>& n, const std::optional>& axes, - StreamOrDevice s) { + mx::StreamOrDevice s) { if (axes.has_value() && n.has_value()) { - return fft::rfftn(a, n.value(), axes.value(), s); + return mx::fft::rfftn(a, n.value(), axes.value(), s); } else if (axes.has_value()) { - return fft::rfftn(a, axes.value(), s); + return mx::fft::rfftn(a, axes.value(), s); } else if (n.has_value()) { throw std::invalid_argument( "[rfftn] `axes` should not be `None` if `s` is not `None`."); } else { - return fft::rfftn(a, s); + return mx::fft::rfftn(a, s); } }, "a"_a, @@ -420,19 +419,19 @@ void init_fft(nb::module_& parent_module) { )pbdoc"); m.def( "irfftn", - [](const array& a, + [](const mx::array& a, const std::optional>& n, const std::optional>& axes, - StreamOrDevice s) { + mx::StreamOrDevice s) { if (axes.has_value() && n.has_value()) { - return fft::irfftn(a, n.value(), axes.value(), s); + return mx::fft::irfftn(a, n.value(), axes.value(), s); } else if (axes.has_value()) { - return fft::irfftn(a, axes.value(), s); + return mx::fft::irfftn(a, axes.value(), s); } else if (n.has_value()) { throw std::invalid_argument( "[irfftn] `axes` should not be `None` if `s` is not `None`."); } else { - return fft::irfftn(a, s); + return mx::fft::irfftn(a, s); } }, "a"_a, diff --git a/python/src/indexing.cpp b/python/src/indexing.cpp index d092f30c2..6261f2603 100644 --- a/python/src/indexing.cpp +++ b/python/src/indexing.cpp @@ -43,20 +43,20 @@ void get_slice_params( nb::getattr(in_slice, "stop"), strides < 0 ? -axis_size - 1 : axis_size); } -array get_int_index(nb::object idx, int axis_size) { +mx::array get_int_index(nb::object idx, int axis_size) { int idx_ = nb::cast(idx); idx_ = (idx_ < 0) ? idx_ + axis_size : idx_; - return array(idx_, uint32); + return mx::array(idx_, mx::uint32); } bool is_valid_index_type(const nb::object& obj) { return nb::isinstance(obj) || nb::isinstance(obj) || - nb::isinstance(obj) || obj.is_none() || nb::ellipsis().is(obj) || - nb::isinstance(obj); + nb::isinstance(obj) || obj.is_none() || + nb::ellipsis().is(obj) || nb::isinstance(obj); } -array mlx_get_item_slice(const array& src, const nb::slice& in_slice) { +mx::array mlx_get_item_slice(const mx::array& src, const nb::slice& in_slice) { // Check input and raise error if 0 dim for parity with np if (src.ndim() == 0) { throw std::invalid_argument( @@ -77,14 +77,14 @@ array mlx_get_item_slice(const array& src, const nb::slice& in_slice) { return slice(src, starts, ends, strides); } -array mlx_get_item_array(const array& src, const array& indices) { +mx::array mlx_get_item_array(const mx::array& src, const mx::array& indices) { // Check input and raise error if 0 dim for parity with np if (src.ndim() == 0) { throw std::invalid_argument( "too many indices for array: array is 0-dimensional"); } - if (indices.dtype() == bool_) { + if (indices.dtype() == mx::bool_) { throw std::invalid_argument("boolean indices are not yet supported"); } @@ -93,7 +93,7 @@ array mlx_get_item_array(const array& src, const array& indices) { return take(src, indices, 0); } -array mlx_get_item_int(const array& src, const nb::int_& idx) { +mx::array mlx_get_item_int(const mx::array& src, const nb::int_& idx) { // Check input and raise error if 0 dim for parity with np if (src.ndim() == 0) { throw std::invalid_argument( @@ -105,13 +105,13 @@ array mlx_get_item_int(const array& src, const nb::int_& idx) { return take(src, get_int_index(idx, src.shape(0)), 0); } -array mlx_gather_nd( - array src, +mx::array mlx_gather_nd( + mx::array src, const std::vector& indices, bool gather_first, int& max_dims) { max_dims = 0; - std::vector gather_indices; + std::vector gather_indices; std::vector is_slice(indices.size(), false); int num_slices = 0; // gather all the arrays @@ -127,13 +127,13 @@ array mlx_gather_nd( start = (start < 0) ? start + src.shape(i) : start; end = (end < 0) ? end + src.shape(i) : end; - gather_indices.push_back(arange(start, end, stride, uint32)); + gather_indices.push_back(arange(start, end, stride, mx::uint32)); num_slices++; is_slice[i] = true; } else if (nb::isinstance(idx)) { gather_indices.push_back(get_int_index(idx, src.shape(i))); - } else if (nb::isinstance(idx)) { - auto arr = nb::cast(idx); + } else if (nb::isinstance(idx)) { + auto arr = nb::cast(idx); max_dims = std::max(static_cast(arr.ndim()), max_dims); gather_indices.push_back(arr); } @@ -144,7 +144,7 @@ array mlx_gather_nd( int slice_index = 0; for (int i = 0; i < gather_indices.size(); i++) { if (is_slice[i]) { - Shape index_shape(max_dims + num_slices, 1); + mx::Shape index_shape(max_dims + num_slices, 1); index_shape[max_dims + slice_index] = gather_indices[i].shape(0); gather_indices[i] = reshape(gather_indices[i], std::move(index_shape)); slice_index++; @@ -158,7 +158,7 @@ array mlx_gather_nd( // reshape them so that the int/array indices are last for (int i = 0; i < gather_indices.size(); i++) { if (i < num_slices) { - Shape index_shape(max_dims + num_slices, 1); + mx::Shape index_shape(max_dims + num_slices, 1); index_shape[i] = gather_indices[i].shape(0); gather_indices[i] = reshape(gather_indices[i], std::move(index_shape)); } @@ -241,7 +241,7 @@ auto mlx_expand_ellipsis( return std::make_pair(non_none_indices, indices); } -array mlx_get_item_nd(array src, const nb::tuple& entries) { +mx::array mlx_get_item_nd(mx::array src, const nb::tuple& entries) { // No indices make this a noop if (entries.size() == 0) { return src; @@ -281,7 +281,7 @@ array mlx_get_item_nd(array src, const nb::tuple& entries) { bool have_non_array = false; bool gather_first = false; for (auto& idx : indices) { - if (nb::isinstance(idx) || (nb::isinstance(idx))) { + if (nb::isinstance(idx) || (nb::isinstance(idx))) { if (have_array && have_non_array) { gather_first = true; break; @@ -294,7 +294,7 @@ array mlx_get_item_nd(array src, const nb::tuple& entries) { int n_arr = 0; for (auto& idx : indices) { - n_arr += nb::isinstance(idx); + n_arr += nb::isinstance(idx); } have_array &= n_arr > 0; @@ -304,7 +304,7 @@ array mlx_get_item_nd(array src, const nb::tuple& entries) { // Then find the last array for (last_array = indices.size() - 1; last_array >= 0; last_array--) { auto& idx = indices[last_array]; - if (nb::isinstance(idx) || nb::isinstance(idx)) { + if (nb::isinstance(idx) || nb::isinstance(idx)) { break; } } @@ -340,7 +340,7 @@ array mlx_get_item_nd(array src, const nb::tuple& entries) { } else { for (int i = 0; i < indices.size(); i++) { auto& idx = indices[i]; - if (nb::isinstance(idx) || nb::isinstance(idx)) { + if (nb::isinstance(idx) || nb::isinstance(idx)) { break; } else if (idx.is_none()) { remaining_indices.push_back(idx); @@ -426,11 +426,11 @@ array mlx_get_item_nd(array src, const nb::tuple& entries) { return src; } -array mlx_get_item(const array& src, const nb::object& obj) { +mx::array mlx_get_item(const mx::array& src, const nb::object& obj) { if (nb::isinstance(obj)) { return mlx_get_item_slice(src, nb::cast(obj)); - } else if (nb::isinstance(obj)) { - return mlx_get_item_array(src, nb::cast(obj)); + } else if (nb::isinstance(obj)) { + return mlx_get_item_array(src, nb::cast(obj)); } else if (nb::isinstance(obj)) { return mlx_get_item_int(src, nb::cast(obj)); } else if (nb::isinstance(obj)) { @@ -448,10 +448,11 @@ array mlx_get_item(const array& src, const nb::object& obj) { throw std::invalid_argument("Cannot index mlx array using the given type."); } -std::tuple, array, std::vector> mlx_scatter_args_int( - const array& src, +std::tuple, mx::array, std::vector> +mlx_scatter_args_int( + const mx::array& src, const nb::int_& idx, - const array& update) { + const mx::array& update) { if (src.ndim() == 0) { throw std::invalid_argument( "too many indices for array: array is 0-dimensional"); @@ -473,10 +474,11 @@ std::tuple, array, std::vector> mlx_scatter_args_int( {0}}; } -std::tuple, array, std::vector> mlx_scatter_args_array( - const array& src, - const array& indices, - const array& update) { +std::tuple, mx::array, std::vector> +mlx_scatter_args_array( + const mx::array& src, + const mx::array& indices, + const mx::array& update) { if (src.ndim() == 0) { throw std::invalid_argument( "too many indices for array: array is 0-dimensional"); @@ -500,10 +502,11 @@ std::tuple, array, std::vector> mlx_scatter_args_array( return {{indices}, up, {0}}; } -std::tuple, array, std::vector> mlx_scatter_args_slice( - const array& src, +std::tuple, mx::array, std::vector> +mlx_scatter_args_slice( + const mx::array& src, const nb::slice& in_slice, - const array& update) { + const mx::array& update) { // Check input and raise error if 0 dim for parity with np if (src.ndim() == 0) { throw std::invalid_argument( @@ -539,7 +542,7 @@ std::tuple, array, std::vector> mlx_scatter_args_slice( auto up = reshape(update, up_shape); // Build array to mark start of slice - auto idx = array({start}, {1}, uint32); + auto idx = mx::array({start}, {1}, mx::uint32); // Get slice size int slice_size = (end - start); @@ -551,20 +554,21 @@ std::tuple, array, std::vector> mlx_scatter_args_slice( up = broadcast_to(up, up_shape_broadcast); - auto indices = std::vector{idx}; + auto indices = std::vector{idx}; auto axes = std::vector{0}; return {indices, up, axes}; } return mlx_scatter_args_array( - src, arange(start, end, stride, uint32), update); + src, arange(start, end, stride, mx::uint32), update); } -std::tuple, array, std::vector> mlx_scatter_args_nd( - const array& src, +std::tuple, mx::array, std::vector> +mlx_scatter_args_nd( + const mx::array& src, const nb::tuple& entries, - const array& update) { + const mx::array& update) { // Expand ellipses into a series of ':' slices auto [non_none_indices, indices] = mlx_expand_ellipsis(src.shape(), entries); @@ -623,12 +627,12 @@ std::tuple, array, std::vector> mlx_scatter_args_nd( num_simple_slices_post++; } - } else if (nb::isinstance(idx)) { + } else if (nb::isinstance(idx)) { have_array = true; if (have_array && have_non_array) { arrays_first = true; } - max_dim = std::max(nb::cast(idx).ndim(), max_dim); + max_dim = std::max(nb::cast(idx).ndim(), max_dim); num_arrays++; num_simple_slices_post = 0; } @@ -643,7 +647,7 @@ std::tuple, array, std::vector> mlx_scatter_args_nd( idx_ndim = idx_ndim == 0 ? 1 : idx_ndim; // Go over each index type and translate to the needed scatter args - std::vector arr_indices; + std::vector arr_indices; int slice_num = 0; int array_num = 0; int ax = 0; @@ -668,7 +672,7 @@ std::tuple, array, std::vector> mlx_scatter_args_nd( // If it's a simple slice, we only need to add the start index if (array_num >= num_arrays && num_strided_slices <= 0 && stride == 1) { - auto idx = array({start}, idx_shape, uint32); + auto idx = mx::array({start}, idx_shape, mx::uint32); slice_shapes.push_back(end - start); arr_indices.push_back(idx); @@ -677,7 +681,7 @@ std::tuple, array, std::vector> mlx_scatter_args_nd( } // Otherwise we expand the slice into indices using arange else { - auto idx = arange(start, end, stride, uint32); + auto idx = arange(start, end, stride, mx::uint32); auto loc = slice_num + (arrays_first ? max_dim : 0); idx_shape[loc] = idx.size(); arr_indices.push_back(reshape(idx, idx_shape)); @@ -696,9 +700,9 @@ std::tuple, array, std::vector> mlx_scatter_args_nd( } else if (pyidx.is_none()) { // We only use the None's for bookeeping dimensions slice_num++; - } else if (nb::isinstance(pyidx)) { + } else if (nb::isinstance(pyidx)) { ax++; - auto idx = nb::cast(pyidx); + auto idx = nb::cast(pyidx); std::vector idx_shape(idx_ndim, 1); // Place the arrays in the correct dimension @@ -748,16 +752,16 @@ std::tuple, array, std::vector> mlx_scatter_args_nd( return {arr_indices, up, axes}; } -std::tuple, array, std::vector> +std::tuple, mx::array, std::vector> mlx_compute_scatter_args( - const array& src, + const mx::array& src, const nb::object& obj, const ScalarOrArray& v) { auto vals = to_array(v, src.dtype()); if (nb::isinstance(obj)) { return mlx_scatter_args_slice(src, nb::cast(obj), vals); - } else if (nb::isinstance(obj)) { - return mlx_scatter_args_array(src, nb::cast(obj), vals); + } else if (nb::isinstance(obj)) { + return mlx_scatter_args_array(src, nb::cast(obj), vals); } else if (nb::isinstance(obj)) { return mlx_scatter_args_int(src, nb::cast(obj), vals); } else if (nb::isinstance(obj)) { @@ -773,7 +777,7 @@ mlx_compute_scatter_args( } auto mlx_slice_update( - const array& src, + const mx::array& src, const nb::object& obj, const ScalarOrArray& v) { // Can't route to slice update if not slice or tuple @@ -784,7 +788,7 @@ auto mlx_slice_update( if (nb::isinstance(obj)) { // Can't route to slice update if any arrays are present for (auto idx : nb::cast(obj)) { - if (nb::isinstance(idx) || nb::isinstance(idx)) { + if (nb::isinstance(idx) || nb::isinstance(idx)) { return std::make_pair(false, src); } } @@ -881,7 +885,10 @@ auto mlx_slice_update( return std::make_pair(true, out); } -void mlx_set_item(array& src, const nb::object& obj, const ScalarOrArray& v) { +void mlx_set_item( + mx::array& src, + const nb::object& obj, + const ScalarOrArray& v) { auto [success, out] = mlx_slice_update(src, obj, v); if (success) { src.overwrite_descriptor(out); @@ -897,8 +904,8 @@ void mlx_set_item(array& src, const nb::object& obj, const ScalarOrArray& v) { } } -array mlx_add_item( - const array& src, +mx::array mlx_add_item( + const mx::array& src, const nb::object& obj, const ScalarOrArray& v) { auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v); @@ -909,8 +916,8 @@ array mlx_add_item( } } -array mlx_subtract_item( - const array& src, +mx::array mlx_subtract_item( + const mx::array& src, const nb::object& obj, const ScalarOrArray& v) { auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v); @@ -921,8 +928,8 @@ array mlx_subtract_item( } } -array mlx_multiply_item( - const array& src, +mx::array mlx_multiply_item( + const mx::array& src, const nb::object& obj, const ScalarOrArray& v) { auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v); @@ -933,8 +940,8 @@ array mlx_multiply_item( } } -array mlx_divide_item( - const array& src, +mx::array mlx_divide_item( + const mx::array& src, const nb::object& obj, const ScalarOrArray& v) { auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v); @@ -945,8 +952,8 @@ array mlx_divide_item( } } -array mlx_maximum_item( - const array& src, +mx::array mlx_maximum_item( + const mx::array& src, const nb::object& obj, const ScalarOrArray& v) { auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v); @@ -957,8 +964,8 @@ array mlx_maximum_item( } } -array mlx_minimum_item( - const array& src, +mx::array mlx_minimum_item( + const mx::array& src, const nb::object& obj, const ScalarOrArray& v) { auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v); diff --git a/python/src/indexing.h b/python/src/indexing.h index 91ea17233..c941425bf 100644 --- a/python/src/indexing.h +++ b/python/src/indexing.h @@ -7,32 +7,35 @@ #include "mlx/array.h" #include "python/src/utils.h" +namespace mx = mlx::core; namespace nb = nanobind; -using namespace mlx::core; -array mlx_get_item(const array& src, const nb::object& obj); -void mlx_set_item(array& src, const nb::object& obj, const ScalarOrArray& v); -array mlx_add_item( - const array& src, +mx::array mlx_get_item(const mx::array& src, const nb::object& obj); +void mlx_set_item( + mx::array& src, const nb::object& obj, const ScalarOrArray& v); -array mlx_subtract_item( - const array& src, +mx::array mlx_add_item( + const mx::array& src, const nb::object& obj, const ScalarOrArray& v); -array mlx_multiply_item( - const array& src, +mx::array mlx_subtract_item( + const mx::array& src, const nb::object& obj, const ScalarOrArray& v); -array mlx_divide_item( - const array& src, +mx::array mlx_multiply_item( + const mx::array& src, const nb::object& obj, const ScalarOrArray& v); -array mlx_maximum_item( - const array& src, +mx::array mlx_divide_item( + const mx::array& src, const nb::object& obj, const ScalarOrArray& v); -array mlx_minimum_item( - const array& src, +mx::array mlx_maximum_item( + const mx::array& src, + const nb::object& obj, + const ScalarOrArray& v); +mx::array mlx_minimum_item( + const mx::array& src, const nb::object& obj, const ScalarOrArray& v); diff --git a/python/src/linalg.cpp b/python/src/linalg.cpp index e2c3aea23..e3dfcb32a 100644 --- a/python/src/linalg.cpp +++ b/python/src/linalg.cpp @@ -10,15 +10,13 @@ #include "mlx/linalg.h" +namespace mx = mlx::core; namespace nb = nanobind; using namespace nb::literals; -using namespace mlx::core; -using namespace mlx::core::linalg; - namespace { -nb::tuple svd_helper(const array& a, StreamOrDevice s /* = {} */) { - const auto result = svd(a, s); +nb::tuple svd_helper(const mx::array& a, mx::StreamOrDevice s /* = {} */) { + const auto result = mx::linalg::svd(a, s); return nb::make_tuple(result.at(0), result.at(1), result.at(2)); } } // namespace @@ -29,11 +27,11 @@ void init_linalg(nb::module_& parent_module) { m.def( "norm", - [](const array& a, + [](const mx::array& a, const std::variant& ord_, const std::variant>& axis_, const bool keepdims, - const StreamOrDevice stream) { + const mx::StreamOrDevice stream) { std::optional> axis = std::nullopt; if (auto pv = std::get_if(&axis_); pv) { axis = std::vector{*pv}; @@ -42,10 +40,10 @@ void init_linalg(nb::module_& parent_module) { } if (std::holds_alternative(ord_)) { - return norm(a, axis, keepdims, stream); + return mx::linalg::norm(a, axis, keepdims, stream); } else { if (auto pv = std::get_if(&ord_); pv) { - return norm(a, *pv, axis, keepdims, stream); + return mx::linalg::norm(a, *pv, axis, keepdims, stream); } double ord; if (auto pv = std::get_if(&ord_); pv) { @@ -53,7 +51,7 @@ void init_linalg(nb::module_& parent_module) { } else { ord = std::get(ord_); } - return norm(a, ord, axis, keepdims, stream); + return mx::linalg::norm(a, ord, axis, keepdims, stream); } }, nb::arg(), @@ -182,7 +180,7 @@ void init_linalg(nb::module_& parent_module) { )pbdoc"); m.def( "qr", - &qr, + &mx::linalg::qr, "a"_a, nb::kw_only(), "stream"_a = nb::none(), @@ -239,7 +237,7 @@ void init_linalg(nb::module_& parent_module) { )pbdoc"); m.def( "inv", - &inv, + &mx::linalg::inv, "a"_a, nb::kw_only(), "stream"_a = nb::none(), @@ -262,7 +260,7 @@ void init_linalg(nb::module_& parent_module) { )pbdoc"); m.def( "tri_inv", - &tri_inv, + &mx::linalg::tri_inv, "a"_a, "upper"_a, nb::kw_only(), @@ -287,7 +285,7 @@ void init_linalg(nb::module_& parent_module) { )pbdoc"); m.def( "cholesky", - &cholesky, + &mx::linalg::cholesky, "a"_a, "upper"_a = false, nb::kw_only(), @@ -317,7 +315,7 @@ void init_linalg(nb::module_& parent_module) { )pbdoc"); m.def( "cholesky_inv", - &cholesky_inv, + &mx::linalg::cholesky_inv, "a"_a, "upper"_a = false, nb::kw_only(), @@ -355,7 +353,7 @@ void init_linalg(nb::module_& parent_module) { )pbdoc"); m.def( "pinv", - &pinv, + &mx::linalg::pinv, "a"_a, nb::kw_only(), "stream"_a = nb::none(), @@ -379,7 +377,7 @@ void init_linalg(nb::module_& parent_module) { )pbdoc"); m.def( "cross", - &cross, + &mx::linalg::cross, "a"_a, "b"_a, "axis"_a = -1, @@ -407,7 +405,7 @@ void init_linalg(nb::module_& parent_module) { )pbdoc"); m.def( "eigvalsh", - &eigvalsh, + &mx::linalg::eigvalsh, "a"_a, "UPLO"_a = "L", nb::kw_only(), @@ -442,9 +440,9 @@ void init_linalg(nb::module_& parent_module) { )pbdoc"); m.def( "eigh", - [](const array& a, const std::string UPLO, StreamOrDevice s) { + [](const mx::array& a, const std::string UPLO, mx::StreamOrDevice s) { // TODO avoid cast? - auto result = eigh(a, UPLO, s); + auto result = mx::linalg::eigh(a, UPLO, s); return nb::make_tuple(result.first, result.second); }, "a"_a, diff --git a/python/src/load.cpp b/python/src/load.cpp index 84530bd46..66e8ecc5a 100644 --- a/python/src/load.cpp +++ b/python/src/load.cpp @@ -14,9 +14,9 @@ #include "python/src/load.h" #include "python/src/utils.h" +namespace mx = mlx::core; namespace nb = nanobind; using namespace nb::literals; -using namespace mlx::core; /////////////////////////////////////////////////////////////////////////////// // Helpers @@ -86,7 +86,7 @@ class ZipFileWrapper { // Loading /////////////////////////////////////////////////////////////////////////////// -class PyFileReader : public io::Reader { +class PyFileReader : public mx::io::Reader { public: PyFileReader(nb::object file) : pyistream_(file), @@ -168,14 +168,14 @@ class PyFileReader : public io::Reader { }; std::pair< - std::unordered_map, + std::unordered_map, std::unordered_map> -mlx_load_safetensor_helper(nb::object file, StreamOrDevice s) { +mlx_load_safetensor_helper(nb::object file, mx::StreamOrDevice s) { if (nb::isinstance(file)) { // Assume .safetensors file path string - return load_safetensors(nb::cast(file), s); + return mx::load_safetensors(nb::cast(file), s); } else if (is_istream_object(file)) { // If we don't own the stream and it was passed to us, eval immediately - auto res = load_safetensors(std::make_shared(file), s); + auto res = mx::load_safetensors(std::make_shared(file), s); { nb::gil_scoped_release gil; for (auto& [key, arr] : std::get<0>(res)) { @@ -189,17 +189,17 @@ mlx_load_safetensor_helper(nb::object file, StreamOrDevice s) { "[load_safetensors] Input must be a file-like object, or string"); } -GGUFLoad mlx_load_gguf_helper(nb::object file, StreamOrDevice s) { +mx::GGUFLoad mlx_load_gguf_helper(nb::object file, mx::StreamOrDevice s) { if (nb::isinstance(file)) { // Assume .gguf file path string - return load_gguf(nb::cast(file), s); + return mx::load_gguf(nb::cast(file), s); } throw std::invalid_argument("[load_gguf] Input must be a string"); } -std::unordered_map mlx_load_npz_helper( +std::unordered_map mlx_load_npz_helper( nb::object file, - StreamOrDevice s) { + mx::StreamOrDevice s) { bool own_file = nb::isinstance(file); nb::module_ zipfile = nb::module_::import_("zipfile"); @@ -209,7 +209,7 @@ std::unordered_map mlx_load_npz_helper( "opened with zipfile.ZipFile"); } // Output dictionary filename in zip -> loaded array - std::unordered_map array_dict; + std::unordered_map array_dict; // Create python ZipFile object ZipFileWrapper zipfile_object(zipfile, file); @@ -218,7 +218,7 @@ std::unordered_map mlx_load_npz_helper( nb::object sub_file = zipfile_object.open(st); // Create array from python file stream - auto arr = load(std::make_shared(sub_file), s); + auto arr = mx::load(std::make_shared(sub_file), s); // Remove .npy from file if it is there auto key = st; @@ -240,12 +240,12 @@ std::unordered_map mlx_load_npz_helper( return array_dict; } -array mlx_load_npy_helper(nb::object file, StreamOrDevice s) { +mx::array mlx_load_npy_helper(nb::object file, mx::StreamOrDevice s) { if (nb::isinstance(file)) { // Assume .npy file path string - return load(nb::cast(file), s); + return mx::load(nb::cast(file), s); } else if (is_istream_object(file)) { // If we don't own the stream and it was passed to us, eval immediately - auto arr = load(std::make_shared(file), s); + auto arr = mx::load(std::make_shared(file), s); { nb::gil_scoped_release gil; arr.eval(); @@ -260,7 +260,7 @@ LoadOutputTypes mlx_load_helper( nb::object file, std::optional format, bool return_metadata, - StreamOrDevice s) { + mx::StreamOrDevice s) { if (!format.has_value()) { std::string fname; if (nb::isinstance(file)) { @@ -309,7 +309,7 @@ LoadOutputTypes mlx_load_helper( // Saving /////////////////////////////////////////////////////////////////////////////// -class PyFileWriter : public io::Writer { +class PyFileWriter : public mx::io::Writer { public: PyFileWriter(nb::object file) : pyostream_(file), @@ -382,15 +382,15 @@ class PyFileWriter : public io::Writer { nb::object tell_func_; }; -void mlx_save_helper(nb::object file, array a) { +void mlx_save_helper(nb::object file, mx::array a) { if (nb::isinstance(file)) { - save(nb::cast(file), a); + mx::save(nb::cast(file), a); return; } else if (is_ostream_object(file)) { auto writer = std::make_shared(file); { nb::gil_scoped_release gil; - save(writer, a); + mx::save(writer, a); } return; @@ -419,8 +419,9 @@ void mlx_savez_helper( } // Collect args and kwargs - auto arrays_dict = nb::cast>(kwargs); - auto arrays_list = nb::cast>(args); + auto arrays_dict = + nb::cast>(kwargs); + auto arrays_list = nb::cast>(args); for (int i = 0; i < arrays_list.size(); i++) { std::string arr_name = "arr_" + std::to_string(i); @@ -447,7 +448,7 @@ void mlx_savez_helper( auto writer = std::make_shared(py_ostream); { nb::gil_scoped_release nogil; - save(writer, a); + mx::save(writer, a); } } @@ -470,17 +471,18 @@ void mlx_save_safetensor_helper( } else { metadata_map = std::unordered_map(); } - auto arrays_map = nb::cast>(d); + auto arrays_map = nb::cast>(d); if (nb::isinstance(file)) { { nb::gil_scoped_release nogil; - save_safetensors(nb::cast(file), arrays_map, metadata_map); + mx::save_safetensors( + nb::cast(file), arrays_map, metadata_map); } } else if (is_ostream_object(file)) { auto writer = std::make_shared(file); { nb::gil_scoped_release nogil; - save_safetensors(writer, arrays_map, metadata_map); + mx::save_safetensors(writer, arrays_map, metadata_map); } } else { throw std::invalid_argument( @@ -492,19 +494,20 @@ void mlx_save_gguf_helper( nb::object file, nb::dict a, std::optional m) { - auto arrays_map = nb::cast>(a); + auto arrays_map = nb::cast>(a); if (nb::isinstance(file)) { if (m) { auto metadata_map = - nb::cast>(m.value()); + nb::cast>( + m.value()); { nb::gil_scoped_release nogil; - save_gguf(nb::cast(file), arrays_map, metadata_map); + mx::save_gguf(nb::cast(file), arrays_map, metadata_map); } } else { { nb::gil_scoped_release nogil; - save_gguf(nb::cast(file), arrays_map); + mx::save_gguf(nb::cast(file), arrays_map); } } } else { diff --git a/python/src/load.h b/python/src/load.h index 90ddeb8b3..4188822f9 100644 --- a/python/src/load.h +++ b/python/src/load.h @@ -14,22 +14,24 @@ #include #include "mlx/io.h" +namespace mx = mlx::core; namespace nb = nanobind; -using namespace mlx::core; using LoadOutputTypes = std::variant< - array, - std::unordered_map, - SafetensorsLoad, - GGUFLoad>; + mx::array, + std::unordered_map, + mx::SafetensorsLoad, + mx::GGUFLoad>; -SafetensorsLoad mlx_load_safetensor_helper(nb::object file, StreamOrDevice s); +mx::SafetensorsLoad mlx_load_safetensor_helper( + nb::object file, + mx::StreamOrDevice s); void mlx_save_safetensor_helper( nb::object file, nb::dict d, std::optional m); -GGUFLoad mlx_load_gguf_helper(nb::object file, StreamOrDevice s); +mx::GGUFLoad mlx_load_gguf_helper(nb::object file, mx::StreamOrDevice s); void mlx_save_gguf_helper( nb::object file, @@ -40,8 +42,8 @@ LoadOutputTypes mlx_load_helper( nb::object file, std::optional format, bool return_metadata, - StreamOrDevice s); -void mlx_save_helper(nb::object file, array a); + mx::StreamOrDevice s); +void mlx_save_helper(nb::object file, mx::array a); void mlx_savez_helper( nb::object file, nb::args args, diff --git a/python/src/metal.cpp b/python/src/metal.cpp index c08bd6c50..646956af8 100644 --- a/python/src/metal.cpp +++ b/python/src/metal.cpp @@ -8,22 +8,21 @@ #include #include +namespace mx = mlx::core; namespace nb = nanobind; using namespace nb::literals; -using namespace mlx::core; - void init_metal(nb::module_& m) { nb::module_ metal = m.def_submodule("metal", "mlx.metal"); metal.def( "is_available", - &metal::is_available, + &mx::metal::is_available, R"pbdoc( Check if the Metal back-end is available. )pbdoc"); metal.def( "get_active_memory", - &metal::get_active_memory, + &mx::metal::get_active_memory, R"pbdoc( Get the actively used memory in bytes. @@ -32,7 +31,7 @@ void init_metal(nb::module_& m) { )pbdoc"); metal.def( "get_peak_memory", - &metal::get_peak_memory, + &mx::metal::get_peak_memory, R"pbdoc( Get the peak amount of used memory in bytes. @@ -41,13 +40,13 @@ void init_metal(nb::module_& m) { )pbdoc"); metal.def( "reset_peak_memory", - &metal::reset_peak_memory, + &mx::metal::reset_peak_memory, R"pbdoc( Reset the peak memory to zero. )pbdoc"); metal.def( "get_cache_memory", - &metal::get_cache_memory, + &mx::metal::get_cache_memory, R"pbdoc( Get the cache size in bytes. @@ -56,7 +55,7 @@ void init_metal(nb::module_& m) { )pbdoc"); metal.def( "set_memory_limit", - &metal::set_memory_limit, + &mx::metal::set_memory_limit, "limit"_a, nb::kw_only(), "relaxed"_a = true, @@ -81,7 +80,7 @@ void init_metal(nb::module_& m) { )pbdoc"); metal.def( "set_cache_limit", - &metal::set_cache_limit, + &mx::metal::set_cache_limit, "limit"_a, R"pbdoc( Set the free cache limit. @@ -101,7 +100,7 @@ void init_metal(nb::module_& m) { )pbdoc"); metal.def( "set_wired_limit", - &metal::set_wired_limit, + &mx::metal::set_wired_limit, "limit"_a, R"pbdoc( Set the wired size limit. @@ -133,7 +132,7 @@ void init_metal(nb::module_& m) { )pbdoc"); metal.def( "clear_cache", - &metal::clear_cache, + &mx::metal::clear_cache, R"pbdoc( Clear the memory cache. @@ -142,7 +141,7 @@ void init_metal(nb::module_& m) { metal.def( "start_capture", - &metal::start_capture, + &mx::metal::start_capture, "path"_a, R"pbdoc( Start a Metal capture. @@ -153,13 +152,13 @@ void init_metal(nb::module_& m) { )pbdoc"); metal.def( "stop_capture", - &metal::stop_capture, + &mx::metal::stop_capture, R"pbdoc( Stop a Metal capture. )pbdoc"); metal.def( "device_info", - &metal::device_info, + &mx::metal::device_info, R"pbdoc( Get information about the GPU device and system settings. diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 1becce7e8..d268a865e 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -18,17 +18,17 @@ #include "python/src/load.h" #include "python/src/utils.h" +namespace mx = mlx::core; namespace nb = nanobind; using namespace nb::literals; -using namespace mlx::core; using Scalar = std::variant; -Dtype scalar_to_dtype(Scalar scalar) { +mx::Dtype scalar_to_dtype(Scalar scalar) { if (std::holds_alternative(scalar)) { - return int32; + return mx::int32; } else { - return float32; + return mx::float32; } } @@ -43,7 +43,7 @@ double scalar_to_double(Scalar s) { void init_ops(nb::module_& m) { m.def( "reshape", - &reshape, + &mx::reshape, nb::arg(), "shape"_a, nb::kw_only(), @@ -64,10 +64,12 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "flatten", - [](const array& a, + [](const mx::array& a, int start_axis, int end_axis, - const StreamOrDevice& s) { return flatten(a, start_axis, end_axis); }, + const mx::StreamOrDevice& s) { + return mx::flatten(a, start_axis, end_axis); + }, nb::arg(), "start_axis"_a = 0, "end_axis"_a = -1, @@ -103,13 +105,13 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "squeeze", - [](const array& a, const IntOrVec& v, const StreamOrDevice& s) { + [](const mx::array& a, const IntOrVec& v, const mx::StreamOrDevice& s) { if (std::holds_alternative(v)) { - return squeeze(a, s); + return mx::squeeze(a, s); } else if (auto pv = std::get_if(&v); pv) { - return squeeze(a, *pv, s); + return mx::squeeze(a, *pv, s); } else { - return squeeze(a, std::get>(v), s); + return mx::squeeze(a, std::get>(v), s); } }, nb::arg(), @@ -132,13 +134,13 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "expand_dims", - [](const array& a, + [](const mx::array& a, const std::variant>& v, - StreamOrDevice s) { + mx::StreamOrDevice s) { if (auto pv = std::get_if(&v); pv) { - return expand_dims(a, *pv, s); + return mx::expand_dims(a, *pv, s); } else { - return expand_dims(a, std::get>(v), s); + return mx::expand_dims(a, std::get>(v), s); } }, nb::arg(), @@ -159,8 +161,8 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "abs", - [](const ScalarOrArray& a, StreamOrDevice s) { - return mlx::core::abs(to_array(a), s); + [](const ScalarOrArray& a, mx::StreamOrDevice s) { + return mx::abs(to_array(a), s); }, nb::arg(), nb::kw_only(), @@ -178,8 +180,8 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "sign", - [](const ScalarOrArray& a, StreamOrDevice s) { - return sign(to_array(a), s); + [](const ScalarOrArray& a, mx::StreamOrDevice s) { + return mx::sign(to_array(a), s); }, nb::arg(), nb::kw_only(), @@ -197,8 +199,8 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "negative", - [](const ScalarOrArray& a, StreamOrDevice s) { - return negative(to_array(a), s); + [](const ScalarOrArray& a, mx::StreamOrDevice s) { + return mx::negative(to_array(a), s); }, nb::arg(), nb::kw_only(), @@ -216,9 +218,11 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "add", - [](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) { + [](const ScalarOrArray& a_, + const ScalarOrArray& b_, + mx::StreamOrDevice s) { auto [a, b] = to_arrays(a_, b_); - return add(a, b, s); + return mx::add(a, b, s); }, nb::arg(), nb::arg(), @@ -241,9 +245,11 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "subtract", - [](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) { + [](const ScalarOrArray& a_, + const ScalarOrArray& b_, + mx::StreamOrDevice s) { auto [a, b] = to_arrays(a_, b_); - return subtract(a, b, s); + return mx::subtract(a, b, s); }, nb::arg(), nb::arg(), @@ -266,9 +272,11 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "multiply", - [](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) { + [](const ScalarOrArray& a_, + const ScalarOrArray& b_, + mx::StreamOrDevice s) { auto [a, b] = to_arrays(a_, b_); - return multiply(a, b, s); + return mx::multiply(a, b, s); }, nb::arg(), nb::arg(), @@ -291,9 +299,11 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "divide", - [](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) { + [](const ScalarOrArray& a_, + const ScalarOrArray& b_, + mx::StreamOrDevice s) { auto [a, b] = to_arrays(a_, b_); - return divide(a, b, s); + return mx::divide(a, b, s); }, nb::arg(), nb::arg(), @@ -316,9 +326,11 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "divmod", - [](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) { + [](const ScalarOrArray& a_, + const ScalarOrArray& b_, + mx::StreamOrDevice s) { auto [a, b] = to_arrays(a_, b_); - return divmod(a, b, s); + return mx::divmod(a, b, s); }, nb::arg(), nb::arg(), @@ -342,9 +354,11 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "floor_divide", - [](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) { + [](const ScalarOrArray& a_, + const ScalarOrArray& b_, + mx::StreamOrDevice s) { auto [a, b] = to_arrays(a_, b_); - return floor_divide(a, b, s); + return mx::floor_divide(a, b, s); }, nb::arg(), nb::arg(), @@ -367,9 +381,11 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "remainder", - [](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) { + [](const ScalarOrArray& a_, + const ScalarOrArray& b_, + mx::StreamOrDevice s) { auto [a, b] = to_arrays(a_, b_); - return remainder(a, b, s); + return mx::remainder(a, b, s); }, nb::arg(), nb::arg(), @@ -393,9 +409,11 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "equal", - [](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) { + [](const ScalarOrArray& a_, + const ScalarOrArray& b_, + mx::StreamOrDevice s) { auto [a, b] = to_arrays(a_, b_); - return equal(a, b, s); + return mx::equal(a, b, s); }, nb::arg(), nb::arg(), @@ -418,9 +436,11 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "not_equal", - [](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) { + [](const ScalarOrArray& a_, + const ScalarOrArray& b_, + mx::StreamOrDevice s) { auto [a, b] = to_arrays(a_, b_); - return not_equal(a, b, s); + return mx::not_equal(a, b, s); }, nb::arg(), nb::arg(), @@ -443,9 +463,11 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "less", - [](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) { + [](const ScalarOrArray& a_, + const ScalarOrArray& b_, + mx::StreamOrDevice s) { auto [a, b] = to_arrays(a_, b_); - return less(a, b, s); + return mx::less(a, b, s); }, nb::arg(), nb::arg(), @@ -468,9 +490,11 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "less_equal", - [](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) { + [](const ScalarOrArray& a_, + const ScalarOrArray& b_, + mx::StreamOrDevice s) { auto [a, b] = to_arrays(a_, b_); - return less_equal(a, b, s); + return mx::less_equal(a, b, s); }, nb::arg(), nb::arg(), @@ -493,9 +517,11 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "greater", - [](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) { + [](const ScalarOrArray& a_, + const ScalarOrArray& b_, + mx::StreamOrDevice s) { auto [a, b] = to_arrays(a_, b_); - return greater(a, b, s); + return mx::greater(a, b, s); }, nb::arg(), nb::arg(), @@ -518,9 +544,11 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "greater_equal", - [](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) { + [](const ScalarOrArray& a_, + const ScalarOrArray& b_, + mx::StreamOrDevice s) { auto [a, b] = to_arrays(a_, b_); - return greater_equal(a, b, s); + return mx::greater_equal(a, b, s); }, nb::arg(), nb::arg(), @@ -546,9 +574,9 @@ void init_ops(nb::module_& m) { [](const ScalarOrArray& a_, const ScalarOrArray& b_, bool equal_nan, - StreamOrDevice s) { + mx::StreamOrDevice s) { auto [a, b] = to_arrays(a_, b_); - return array_equal(a, b, equal_nan, s); + return mx::array_equal(a, b, equal_nan, s); }, nb::arg(), nb::arg(), @@ -575,7 +603,7 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "matmul", - &matmul, + &mx::matmul, nb::arg(), nb::arg(), nb::kw_only(), @@ -607,8 +635,8 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "square", - [](const ScalarOrArray& a, StreamOrDevice s) { - return square(to_array(a), s); + [](const ScalarOrArray& a, mx::StreamOrDevice s) { + return mx::square(to_array(a), s); }, nb::arg(), nb::kw_only(), @@ -626,8 +654,8 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "sqrt", - [](const ScalarOrArray& a, StreamOrDevice s) { - return mlx::core::sqrt(to_array(a), s); + [](const ScalarOrArray& a, mx::StreamOrDevice s) { + return mx::sqrt(to_array(a), s); }, nb::arg(), nb::kw_only(), @@ -645,8 +673,8 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "rsqrt", - [](const ScalarOrArray& a, StreamOrDevice s) { - return rsqrt(to_array(a), s); + [](const ScalarOrArray& a, mx::StreamOrDevice s) { + return mx::rsqrt(to_array(a), s); }, nb::arg(), nb::kw_only(), @@ -664,8 +692,8 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "reciprocal", - [](const ScalarOrArray& a, StreamOrDevice s) { - return reciprocal(to_array(a), s); + [](const ScalarOrArray& a, mx::StreamOrDevice s) { + return mx::reciprocal(to_array(a), s); }, nb::arg(), nb::kw_only(), @@ -683,8 +711,8 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "logical_not", - [](const ScalarOrArray& a, StreamOrDevice s) { - return logical_not(to_array(a), s); + [](const ScalarOrArray& a, mx::StreamOrDevice s) { + return mx::logical_not(to_array(a), s); }, nb::arg(), nb::kw_only(), @@ -702,8 +730,8 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "logical_and", - [](const ScalarOrArray& a, const ScalarOrArray& b, StreamOrDevice s) { - return logical_and(to_array(a), to_array(b), s); + [](const ScalarOrArray& a, const ScalarOrArray& b, mx::StreamOrDevice s) { + return mx::logical_and(to_array(a), to_array(b), s); }, nb::arg(), nb::arg(), @@ -724,8 +752,8 @@ void init_ops(nb::module_& m) { m.def( "logical_or", - [](const ScalarOrArray& a, const ScalarOrArray& b, StreamOrDevice s) { - return logical_or(to_array(a), to_array(b), s); + [](const ScalarOrArray& a, const ScalarOrArray& b, mx::StreamOrDevice s) { + return mx::logical_or(to_array(a), to_array(b), s); }, nb::arg(), nb::arg(), @@ -745,9 +773,11 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "logaddexp", - [](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) { + [](const ScalarOrArray& a_, + const ScalarOrArray& b_, + mx::StreamOrDevice s) { auto [a, b] = to_arrays(a_, b_); - return logaddexp(a, b, s); + return mx::logaddexp(a, b, s); }, nb::arg(), nb::arg(), @@ -772,8 +802,8 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "exp", - [](const ScalarOrArray& a, StreamOrDevice s) { - return mlx::core::exp(to_array(a), s); + [](const ScalarOrArray& a, mx::StreamOrDevice s) { + return mx::exp(to_array(a), s); }, nb::arg(), nb::kw_only(), @@ -791,8 +821,8 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "expm1", - [](const ScalarOrArray& a, StreamOrDevice s) { - return mlx::core::expm1(to_array(a), s); + [](const ScalarOrArray& a, mx::StreamOrDevice s) { + return mx::expm1(to_array(a), s); }, nb::arg(), nb::kw_only(), @@ -812,8 +842,8 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "erf", - [](const ScalarOrArray& a, StreamOrDevice s) { - return mlx::core::erf(to_array(a), s); + [](const ScalarOrArray& a, mx::StreamOrDevice s) { + return mx::erf(to_array(a), s); }, nb::arg(), nb::kw_only(), @@ -834,8 +864,8 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "erfinv", - [](const ScalarOrArray& a, StreamOrDevice s) { - return mlx::core::erfinv(to_array(a), s); + [](const ScalarOrArray& a, mx::StreamOrDevice s) { + return mx::erfinv(to_array(a), s); }, nb::arg(), nb::kw_only(), @@ -853,8 +883,8 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "sin", - [](const ScalarOrArray& a, StreamOrDevice s) { - return mlx::core::sin(to_array(a), s); + [](const ScalarOrArray& a, mx::StreamOrDevice s) { + return mx::sin(to_array(a), s); }, nb::arg(), nb::kw_only(), @@ -872,8 +902,8 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "cos", - [](const ScalarOrArray& a, StreamOrDevice s) { - return mlx::core::cos(to_array(a), s); + [](const ScalarOrArray& a, mx::StreamOrDevice s) { + return mx::cos(to_array(a), s); }, nb::arg(), nb::kw_only(), @@ -891,8 +921,8 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "tan", - [](const ScalarOrArray& a, StreamOrDevice s) { - return mlx::core::tan(to_array(a), s); + [](const ScalarOrArray& a, mx::StreamOrDevice s) { + return mx::tan(to_array(a), s); }, nb::arg(), nb::kw_only(), @@ -910,8 +940,8 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "arcsin", - [](const ScalarOrArray& a, StreamOrDevice s) { - return mlx::core::arcsin(to_array(a), s); + [](const ScalarOrArray& a, mx::StreamOrDevice s) { + return mx::arcsin(to_array(a), s); }, nb::arg(), nb::kw_only(), @@ -929,8 +959,8 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "arccos", - [](const ScalarOrArray& a, StreamOrDevice s) { - return mlx::core::arccos(to_array(a), s); + [](const ScalarOrArray& a, mx::StreamOrDevice s) { + return mx::arccos(to_array(a), s); }, nb::arg(), nb::kw_only(), @@ -948,8 +978,8 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "arctan", - [](const ScalarOrArray& a, StreamOrDevice s) { - return mlx::core::arctan(to_array(a), s); + [](const ScalarOrArray& a, mx::StreamOrDevice s) { + return mx::arctan(to_array(a), s); }, nb::arg(), nb::kw_only(), @@ -967,7 +997,7 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "arctan2", - &mlx::core::arctan2, + &mx::arctan2, nb::arg(), nb::arg(), nb::kw_only(), @@ -986,8 +1016,8 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "sinh", - [](const ScalarOrArray& a, StreamOrDevice s) { - return mlx::core::sinh(to_array(a), s); + [](const ScalarOrArray& a, mx::StreamOrDevice s) { + return mx::sinh(to_array(a), s); }, nb::arg(), nb::kw_only(), @@ -1005,8 +1035,8 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "cosh", - [](const ScalarOrArray& a, StreamOrDevice s) { - return mlx::core::cosh(to_array(a), s); + [](const ScalarOrArray& a, mx::StreamOrDevice s) { + return mx::cosh(to_array(a), s); }, nb::arg(), nb::kw_only(), @@ -1024,8 +1054,8 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "tanh", - [](const ScalarOrArray& a, StreamOrDevice s) { - return mlx::core::tanh(to_array(a), s); + [](const ScalarOrArray& a, mx::StreamOrDevice s) { + return mx::tanh(to_array(a), s); }, nb::arg(), nb::kw_only(), @@ -1043,8 +1073,8 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "arcsinh", - [](const ScalarOrArray& a, StreamOrDevice s) { - return mlx::core::arcsinh(to_array(a), s); + [](const ScalarOrArray& a, mx::StreamOrDevice s) { + return mx::arcsinh(to_array(a), s); }, nb::arg(), nb::kw_only(), @@ -1062,8 +1092,8 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "arccosh", - [](const ScalarOrArray& a, StreamOrDevice s) { - return mlx::core::arccosh(to_array(a), s); + [](const ScalarOrArray& a, mx::StreamOrDevice s) { + return mx::arccosh(to_array(a), s); }, nb::arg(), nb::kw_only(), @@ -1081,8 +1111,8 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "arctanh", - [](const ScalarOrArray& a, StreamOrDevice s) { - return mlx::core::arctanh(to_array(a), s); + [](const ScalarOrArray& a, mx::StreamOrDevice s) { + return mx::arctanh(to_array(a), s); }, nb::arg(), nb::kw_only(), @@ -1100,8 +1130,8 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "degrees", - [](const ScalarOrArray& a, StreamOrDevice s) { - return degrees(to_array(a), s); + [](const ScalarOrArray& a, mx::StreamOrDevice s) { + return mx::degrees(to_array(a), s); }, nb::arg(), nb::kw_only(), @@ -1119,8 +1149,8 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "radians", - [](const ScalarOrArray& a, StreamOrDevice s) { - return mlx::core::radians(to_array(a), s); + [](const ScalarOrArray& a, mx::StreamOrDevice s) { + return mx::radians(to_array(a), s); }, nb::arg(), nb::kw_only(), @@ -1138,8 +1168,8 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "log", - [](const ScalarOrArray& a, StreamOrDevice s) { - return mlx::core::log(to_array(a), s); + [](const ScalarOrArray& a, mx::StreamOrDevice s) { + return mx::log(to_array(a), s); }, nb::arg(), nb::kw_only(), @@ -1157,8 +1187,8 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "log2", - [](const ScalarOrArray& a, StreamOrDevice s) { - return mlx::core::log2(to_array(a), s); + [](const ScalarOrArray& a, mx::StreamOrDevice s) { + return mx::log2(to_array(a), s); }, nb::arg(), nb::kw_only(), @@ -1176,8 +1206,8 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "log10", - [](const ScalarOrArray& a, StreamOrDevice s) { - return mlx::core::log10(to_array(a), s); + [](const ScalarOrArray& a, mx::StreamOrDevice s) { + return mx::log10(to_array(a), s); }, nb::arg(), nb::kw_only(), @@ -1195,8 +1225,8 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "log1p", - [](const ScalarOrArray& a, StreamOrDevice s) { - return mlx::core::log1p(to_array(a), s); + [](const ScalarOrArray& a, mx::StreamOrDevice s) { + return mx::log1p(to_array(a), s); }, nb::arg(), nb::kw_only(), @@ -1214,7 +1244,7 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "stop_gradient", - &stop_gradient, + &mx::stop_gradient, nb::arg(), nb::kw_only(), "stream"_a = nb::none(), @@ -1236,8 +1266,8 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "sigmoid", - [](const ScalarOrArray& a, StreamOrDevice s) { - return sigmoid(to_array(a), s); + [](const ScalarOrArray& a, mx::StreamOrDevice s) { + return mx::sigmoid(to_array(a), s); }, nb::arg(), nb::kw_only(), @@ -1260,9 +1290,11 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "power", - [](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) { + [](const ScalarOrArray& a_, + const ScalarOrArray& b_, + mx::StreamOrDevice s) { auto [a, b] = to_arrays(a_, b_); - return power(a, b, s); + return mx::power(a, b, s); }, nb::arg(), nb::arg(), @@ -1288,17 +1320,17 @@ void init_ops(nb::module_& m) { [](Scalar start, Scalar stop, const std::optional& step, - const std::optional& dtype_, - StreamOrDevice s) { + const std::optional& dtype_, + mx::StreamOrDevice s) { // Determine the final dtype based on input types - Dtype dtype = dtype_ + mx::Dtype dtype = dtype_ ? *dtype_ - : promote_types( + : mx::promote_types( scalar_to_dtype(start), - step ? promote_types( + step ? mx::promote_types( scalar_to_dtype(stop), scalar_to_dtype(*step)) : scalar_to_dtype(stop)); - return arange( + return mx::arange( scalar_to_double(start), scalar_to_double(stop), step ? scalar_to_double(*step) : 1.0, @@ -1338,13 +1370,13 @@ void init_ops(nb::module_& m) { "arange", [](Scalar stop, const std::optional& step, - const std::optional& dtype_, - StreamOrDevice s) { - Dtype dtype = dtype_ ? *dtype_ + const std::optional& dtype_, + mx::StreamOrDevice s) { + mx::Dtype dtype = dtype_ ? *dtype_ : step - ? promote_types(scalar_to_dtype(stop), scalar_to_dtype(*step)) + ? mx::promote_types(scalar_to_dtype(stop), scalar_to_dtype(*step)) : scalar_to_dtype(stop); - return arange( + return mx::arange( 0.0, scalar_to_double(stop), step ? scalar_to_double(*step) : 1.0, @@ -1363,19 +1395,19 @@ void init_ops(nb::module_& m) { [](Scalar start, Scalar stop, int num, - std::optional dtype, - StreamOrDevice s) { - return linspace( + std::optional dtype, + mx::StreamOrDevice s) { + return mx::linspace( scalar_to_double(start), scalar_to_double(stop), num, - dtype.value_or(float32), + dtype.value_or(mx::float32), s); }, "start"_a, "stop"_a, "num"_a = 50, - "dtype"_a.none() = float32, + "dtype"_a.none() = mx::float32, "stream"_a = nb::none(), nb::sig( "def linspace(start, stop, num: Optional[int] = 50, dtype: Optional[Dtype] = float32, stream: Union[None, Stream, Device] = None) -> array"), @@ -1394,17 +1426,17 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "take", - [](const array& a, - const std::variant& indices, + [](const mx::array& a, + const std::variant& indices, const std::optional& axis, - StreamOrDevice s) { + mx::StreamOrDevice s) { if (auto pv = std::get_if(&indices); pv) { auto idx = nb::cast(*pv); - return axis ? take(a, idx, axis.value(), s) : take(a, idx, s); + return axis ? mx::take(a, idx, axis.value(), s) : mx::take(a, idx, s); } else { - auto indices_ = std::get(indices); - return axis ? take(a, indices_, axis.value(), s) - : take(a, indices_, s); + auto indices_ = std::get(indices); + return axis ? mx::take(a, indices_, axis.value(), s) + : mx::take(a, indices_, s); } }, nb::arg(), @@ -1434,14 +1466,14 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "take_along_axis", - [](const array& a, - const array& indices, + [](const mx::array& a, + const mx::array& indices, const std::optional& axis, - StreamOrDevice s) { + mx::StreamOrDevice s) { if (axis.has_value()) { - return take_along_axis(a, indices, axis.value(), s); + return mx::take_along_axis(a, indices, axis.value(), s); } else { - return take_along_axis(reshape(a, {-1}, s), indices, 0, s); + return mx::take_along_axis(mx::reshape(a, {-1}, s), indices, 0, s); } }, nb::arg(), @@ -1467,16 +1499,17 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "put_along_axis", - [](const array& a, - const array& indices, - const array& values, + [](const mx::array& a, + const mx::array& indices, + const mx::array& values, const std::optional& axis, - StreamOrDevice s) { + mx::StreamOrDevice s) { if (axis.has_value()) { - return put_along_axis(a, indices, values, axis.value(), s); + return mx::put_along_axis(a, indices, values, axis.value(), s); } else { - return reshape( - put_along_axis(reshape(a, {-1}, s), indices, values, 0, s), + return mx::reshape( + mx::put_along_axis( + mx::reshape(a, {-1}, s), indices, values, 0, s), a.shape(), s); } @@ -1510,12 +1543,12 @@ void init_ops(nb::module_& m) { "full", [](const std::variant>& shape, const ScalarOrArray& vals, - std::optional dtype, - StreamOrDevice s) { + std::optional dtype, + mx::StreamOrDevice s) { if (auto pv = std::get_if(&shape); pv) { - return full({*pv}, to_array(vals, dtype), s); + return mx::full({*pv}, to_array(vals, dtype), s); } else { - return full( + return mx::full( std::get>(shape), to_array(vals, dtype), s); } }, @@ -1544,17 +1577,17 @@ void init_ops(nb::module_& m) { m.def( "zeros", [](const std::variant>& shape, - std::optional dtype, - StreamOrDevice s) { - auto t = dtype.value_or(float32); + std::optional dtype, + mx::StreamOrDevice s) { + auto t = dtype.value_or(mx::float32); if (auto pv = std::get_if(&shape); pv) { - return zeros({*pv}, t, s); + return mx::zeros({*pv}, t, s); } else { - return zeros(std::get>(shape), t, s); + return mx::zeros(std::get>(shape), t, s); } }, "shape"_a, - "dtype"_a.none() = float32, + "dtype"_a.none() = mx::float32, nb::kw_only(), "stream"_a = nb::none(), nb::sig( @@ -1572,7 +1605,7 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "zeros_like", - &zeros_like, + &mx::zeros_like, nb::arg(), nb::kw_only(), "stream"_a = nb::none(), @@ -1590,17 +1623,17 @@ void init_ops(nb::module_& m) { m.def( "ones", [](const std::variant>& shape, - std::optional dtype, - StreamOrDevice s) { - auto t = dtype.value_or(float32); + std::optional dtype, + mx::StreamOrDevice s) { + auto t = dtype.value_or(mx::float32); if (auto pv = std::get_if(&shape); pv) { - return ones({*pv}, t, s); + return mx::ones({*pv}, t, s); } else { - return ones(std::get>(shape), t, s); + return mx::ones(std::get>(shape), t, s); } }, "shape"_a, - "dtype"_a.none() = float32, + "dtype"_a.none() = mx::float32, nb::kw_only(), "stream"_a = nb::none(), nb::sig( @@ -1618,7 +1651,7 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "ones_like", - &ones_like, + &mx::ones_like, nb::arg(), nb::kw_only(), "stream"_a = nb::none(), @@ -1638,14 +1671,14 @@ void init_ops(nb::module_& m) { [](int n, std::optional m, int k, - std::optional dtype, - StreamOrDevice s) { - return eye(n, m.value_or(n), k, dtype.value_or(float32), s); + std::optional dtype, + mx::StreamOrDevice s) { + return mx::eye(n, m.value_or(n), k, dtype.value_or(mx::float32), s); }, "n"_a, "m"_a = nb::none(), "k"_a = 0, - "dtype"_a.none() = float32, + "dtype"_a.none() = mx::float32, nb::kw_only(), "stream"_a = nb::none(), nb::sig( @@ -1665,11 +1698,11 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "identity", - [](int n, std::optional dtype, StreamOrDevice s) { - return identity(n, dtype.value_or(float32), s); + [](int n, std::optional dtype, mx::StreamOrDevice s) { + return mx::identity(n, dtype.value_or(mx::float32), s); }, "n"_a, - "dtype"_a.none() = float32, + "dtype"_a.none() = mx::float32, nb::kw_only(), "stream"_a = nb::none(), nb::sig( @@ -1690,14 +1723,14 @@ void init_ops(nb::module_& m) { [](int n, std::optional m, int k, - std::optional type, - StreamOrDevice s) { - return tri(n, m.value_or(n), k, type.value_or(float32), s); + std::optional type, + mx::StreamOrDevice s) { + return mx::tri(n, m.value_or(n), k, type.value_or(mx::float32), s); }, "n"_a, "m"_a = nb::none(), "k"_a = 0, - "dtype"_a.none() = float32, + "dtype"_a.none() = mx::float32, nb::kw_only(), "stream"_a = nb::none(), nb::sig( @@ -1717,7 +1750,7 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "tril", - &tril, + &mx::tril, "x"_a, "k"_a = 0, nb::kw_only(), @@ -1737,7 +1770,7 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "triu", - &triu, + &mx::triu, "x"_a, "k"_a = 0, nb::kw_only(), @@ -1757,7 +1790,7 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "allclose", - &allclose, + &mx::allclose, nb::arg(), nb::arg(), "rtol"_a = 1e-5, @@ -1794,7 +1827,7 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "isclose", - &isclose, + &mx::isclose, nb::arg(), nb::arg(), "rtol"_a = 1e-5, @@ -1832,11 +1865,11 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "all", - [](const array& a, + [](const mx::array& a, const IntOrVec& axis, bool keepdims, - StreamOrDevice s) { - return all(a, get_reduce_axes(axis, a.ndim()), keepdims, s); + mx::StreamOrDevice s) { + return mx::all(a, get_reduce_axes(axis, a.ndim()), keepdims, s); }, nb::arg(), "axis"_a = nb::none(), @@ -1861,11 +1894,11 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "any", - [](const array& a, + [](const mx::array& a, const IntOrVec& axis, bool keepdims, - StreamOrDevice s) { - return any(a, get_reduce_axes(axis, a.ndim()), keepdims, s); + mx::StreamOrDevice s) { + return mx::any(a, get_reduce_axes(axis, a.ndim()), keepdims, s); }, nb::arg(), "axis"_a = nb::none(), @@ -1890,9 +1923,11 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "minimum", - [](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) { + [](const ScalarOrArray& a_, + const ScalarOrArray& b_, + mx::StreamOrDevice s) { auto [a, b] = to_arrays(a_, b_); - return minimum(a, b, s); + return mx::minimum(a, b, s); }, nb::arg(), nb::arg(), @@ -1915,9 +1950,11 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "maximum", - [](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) { + [](const ScalarOrArray& a_, + const ScalarOrArray& b_, + mx::StreamOrDevice s) { auto [a, b] = to_arrays(a_, b_); - return maximum(a, b, s); + return mx::maximum(a, b, s); }, nb::arg(), nb::arg(), @@ -1940,8 +1977,8 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "floor", - [](const ScalarOrArray& a, StreamOrDevice s) { - return mlx::core::floor(to_array(a), s); + [](const ScalarOrArray& a, mx::StreamOrDevice s) { + return mx::floor(to_array(a), s); }, nb::arg(), nb::kw_only(), @@ -1959,8 +1996,8 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "ceil", - [](const ScalarOrArray& a, StreamOrDevice s) { - return mlx::core::ceil(to_array(a), s); + [](const ScalarOrArray& a, mx::StreamOrDevice s) { + return mx::ceil(to_array(a), s); }, nb::arg(), nb::kw_only(), @@ -1978,8 +2015,8 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "isnan", - [](const ScalarOrArray& a, StreamOrDevice s) { - return mlx::core::isnan(to_array(a), s); + [](const ScalarOrArray& a, mx::StreamOrDevice s) { + return mx::isnan(to_array(a), s); }, nb::arg(), nb::kw_only(), @@ -1997,8 +2034,8 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "isinf", - [](const ScalarOrArray& a, StreamOrDevice s) { - return mlx::core::isinf(to_array(a), s); + [](const ScalarOrArray& a, mx::StreamOrDevice s) { + return mx::isinf(to_array(a), s); }, nb::arg(), nb::kw_only(), @@ -2016,8 +2053,8 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "isfinite", - [](const ScalarOrArray& a, StreamOrDevice s) { - return mlx::core::isfinite(to_array(a), s); + [](const ScalarOrArray& a, mx::StreamOrDevice s) { + return mx::isfinite(to_array(a), s); }, nb::arg(), nb::kw_only(), @@ -2037,8 +2074,8 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "isposinf", - [](const ScalarOrArray& a, StreamOrDevice s) { - return isposinf(to_array(a), s); + [](const ScalarOrArray& a, mx::StreamOrDevice s) { + return mx::isposinf(to_array(a), s); }, nb::arg(), nb::kw_only(), @@ -2057,8 +2094,8 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "isneginf", - [](const ScalarOrArray& a, StreamOrDevice s) { - return isneginf(to_array(a), s); + [](const ScalarOrArray& a, mx::StreamOrDevice s) { + return mx::isneginf(to_array(a), s); }, nb::arg(), nb::kw_only(), @@ -2077,7 +2114,7 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "moveaxis", - &moveaxis, + &mx::moveaxis, nb::arg(), "source"_a, "destination"_a, @@ -2098,7 +2135,7 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "swapaxes", - &swapaxes, + &mx::swapaxes, nb::arg(), "axis1"_a, "axis2"_a, @@ -2119,13 +2156,13 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "transpose", - [](const array& a, + [](const mx::array& a, const std::optional>& axes, - StreamOrDevice s) { + mx::StreamOrDevice s) { if (axes.has_value()) { - return transpose(a, *axes, s); + return mx::transpose(a, *axes, s); } else { - return transpose(a, s); + return mx::transpose(a, s); } }, nb::arg(), @@ -2147,13 +2184,13 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "permute_dims", - [](const array& a, + [](const mx::array& a, const std::optional>& axes, - StreamOrDevice s) { + mx::StreamOrDevice s) { if (axes.has_value()) { - return transpose(a, *axes, s); + return mx::transpose(a, *axes, s); } else { - return transpose(a, s); + return mx::transpose(a, s); } }, nb::arg(), @@ -2167,11 +2204,11 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "sum", - [](const array& a, + [](const mx::array& a, const IntOrVec& axis, bool keepdims, - StreamOrDevice s) { - return sum(a, get_reduce_axes(axis, a.ndim()), keepdims, s); + mx::StreamOrDevice s) { + return mx::sum(a, get_reduce_axes(axis, a.ndim()), keepdims, s); }, "array"_a, "axis"_a = nb::none(), @@ -2196,11 +2233,11 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "prod", - [](const array& a, + [](const mx::array& a, const IntOrVec& axis, bool keepdims, - StreamOrDevice s) { - return prod(a, get_reduce_axes(axis, a.ndim()), keepdims, s); + mx::StreamOrDevice s) { + return mx::prod(a, get_reduce_axes(axis, a.ndim()), keepdims, s); }, nb::arg(), "axis"_a = nb::none(), @@ -2225,11 +2262,11 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "min", - [](const array& a, + [](const mx::array& a, const IntOrVec& axis, bool keepdims, - StreamOrDevice s) { - return min(a, get_reduce_axes(axis, a.ndim()), keepdims, s); + mx::StreamOrDevice s) { + return mx::min(a, get_reduce_axes(axis, a.ndim()), keepdims, s); }, nb::arg(), "axis"_a = nb::none(), @@ -2254,11 +2291,11 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "max", - [](const array& a, + [](const mx::array& a, const IntOrVec& axis, bool keepdims, - StreamOrDevice s) { - return max(a, get_reduce_axes(axis, a.ndim()), keepdims, s); + mx::StreamOrDevice s) { + return mx::max(a, get_reduce_axes(axis, a.ndim()), keepdims, s); }, nb::arg(), "axis"_a = nb::none(), @@ -2283,11 +2320,11 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "logsumexp", - [](const array& a, + [](const mx::array& a, const IntOrVec& axis, bool keepdims, - StreamOrDevice s) { - return logsumexp(a, get_reduce_axes(axis, a.ndim()), keepdims, s); + mx::StreamOrDevice s) { + return mx::logsumexp(a, get_reduce_axes(axis, a.ndim()), keepdims, s); }, nb::arg(), "axis"_a = nb::none(), @@ -2318,11 +2355,11 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "mean", - [](const array& a, + [](const mx::array& a, const IntOrVec& axis, bool keepdims, - StreamOrDevice s) { - return mean(a, get_reduce_axes(axis, a.ndim()), keepdims, s); + mx::StreamOrDevice s) { + return mx::mean(a, get_reduce_axes(axis, a.ndim()), keepdims, s); }, nb::arg(), "axis"_a = nb::none(), @@ -2347,12 +2384,12 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "var", - [](const array& a, + [](const mx::array& a, const IntOrVec& axis, bool keepdims, int ddof, - StreamOrDevice s) { - return var(a, get_reduce_axes(axis, a.ndim()), keepdims, ddof, s); + mx::StreamOrDevice s) { + return mx::var(a, get_reduce_axes(axis, a.ndim()), keepdims, ddof, s); }, nb::arg(), "axis"_a = nb::none(), @@ -2380,13 +2417,12 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "std", - [](const array& a, + [](const mx::array& a, const IntOrVec& axis, bool keepdims, int ddof, - StreamOrDevice s) { - return mlx::core::std( - a, get_reduce_axes(axis, a.ndim()), keepdims, ddof, s); + mx::StreamOrDevice s) { + return mx::std(a, get_reduce_axes(axis, a.ndim()), keepdims, ddof, s); }, nb::arg(), "axis"_a = nb::none(), @@ -2414,14 +2450,14 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "split", - [](const array& a, + [](const mx::array& a, const std::variant>& indices_or_sections, int axis, - StreamOrDevice s) { + mx::StreamOrDevice s) { if (auto pv = std::get_if(&indices_or_sections); pv) { - return split(a, *pv, axis, s); + return mx::split(a, *pv, axis, s); } else { - return split( + return mx::split( a, std::get>(indices_or_sections), axis, s); } }, @@ -2449,14 +2485,14 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "argmin", - [](const array& a, + [](const mx::array& a, std::optional axis, bool keepdims, - StreamOrDevice s) { + mx::StreamOrDevice s) { if (axis) { - return argmin(a, *axis, keepdims, s); + return mx::argmin(a, *axis, keepdims, s); } else { - return argmin(a, keepdims, s); + return mx::argmin(a, keepdims, s); } }, nb::arg(), @@ -2481,14 +2517,14 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "argmax", - [](const array& a, + [](const mx::array& a, std::optional axis, bool keepdims, - StreamOrDevice s) { + mx::StreamOrDevice s) { if (axis) { - return argmax(a, *axis, keepdims, s); + return mx::argmax(a, *axis, keepdims, s); } else { - return argmax(a, keepdims, s); + return mx::argmax(a, keepdims, s); } }, nb::arg(), @@ -2513,11 +2549,11 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "sort", - [](const array& a, std::optional axis, StreamOrDevice s) { + [](const mx::array& a, std::optional axis, mx::StreamOrDevice s) { if (axis) { - return sort(a, *axis, s); + return mx::sort(a, *axis, s); } else { - return sort(a, s); + return mx::sort(a, s); } }, nb::arg(), @@ -2540,11 +2576,11 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "argsort", - [](const array& a, std::optional axis, StreamOrDevice s) { + [](const mx::array& a, std::optional axis, mx::StreamOrDevice s) { if (axis) { - return argsort(a, *axis, s); + return mx::argsort(a, *axis, s); } else { - return argsort(a, s); + return mx::argsort(a, s); } }, nb::arg(), @@ -2567,11 +2603,14 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "partition", - [](const array& a, int kth, std::optional axis, StreamOrDevice s) { + [](const mx::array& a, + int kth, + std::optional axis, + mx::StreamOrDevice s) { if (axis) { - return partition(a, kth, *axis, s); + return mx::partition(a, kth, *axis, s); } else { - return partition(a, kth, s); + return mx::partition(a, kth, s); } }, nb::arg(), @@ -2602,11 +2641,14 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "argpartition", - [](const array& a, int kth, std::optional axis, StreamOrDevice s) { + [](const mx::array& a, + int kth, + std::optional axis, + mx::StreamOrDevice s) { if (axis) { - return argpartition(a, kth, *axis, s); + return mx::argpartition(a, kth, *axis, s); } else { - return argpartition(a, kth, s); + return mx::argpartition(a, kth, s); } }, nb::arg(), @@ -2638,11 +2680,14 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "topk", - [](const array& a, int k, std::optional axis, StreamOrDevice s) { + [](const mx::array& a, + int k, + std::optional axis, + mx::StreamOrDevice s) { if (axis) { - return topk(a, k, *axis, s); + return mx::topk(a, k, *axis, s); } else { - return topk(a, k, s); + return mx::topk(a, k, s); } }, nb::arg(), @@ -2671,7 +2716,9 @@ void init_ops(nb::module_& m) { "broadcast_to", [](const ScalarOrArray& a, const std::vector& shape, - StreamOrDevice s) { return broadcast_to(to_array(a), shape, s); }, + mx::StreamOrDevice s) { + return mx::broadcast_to(to_array(a), shape, s); + }, nb::arg(), "shape"_a, nb::kw_only(), @@ -2692,8 +2739,11 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "softmax", - [](const array& a, const IntOrVec& axis, bool precise, StreamOrDevice s) { - return softmax(a, get_reduce_axes(axis, a.ndim()), precise, s); + [](const mx::array& a, + const IntOrVec& axis, + bool precise, + mx::StreamOrDevice s) { + return mx::softmax(a, get_reduce_axes(axis, a.ndim()), precise, s); }, nb::arg(), "axis"_a = nb::none(), @@ -2722,13 +2772,13 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "concatenate", - [](const std::vector& arrays, + [](const std::vector& arrays, std::optional axis, - StreamOrDevice s) { + mx::StreamOrDevice s) { if (axis) { - return concatenate(arrays, *axis, s); + return mx::concatenate(arrays, *axis, s); } else { - return concatenate(arrays, s); + return mx::concatenate(arrays, s); } }, nb::arg(), @@ -2750,13 +2800,13 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "concat", - [](const std::vector& arrays, + [](const std::vector& arrays, std::optional axis, - StreamOrDevice s) { + mx::StreamOrDevice s) { if (axis) { - return concatenate(arrays, *axis, s); + return mx::concatenate(arrays, *axis, s); } else { - return concatenate(arrays, s); + return mx::concatenate(arrays, s); } }, nb::arg(), @@ -2770,13 +2820,13 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "stack", - [](const std::vector& arrays, + [](const std::vector& arrays, std::optional axis, - StreamOrDevice s) { + mx::StreamOrDevice s) { if (axis.has_value()) { - return stack(arrays, axis.value(), s); + return mx::stack(arrays, axis.value(), s); } else { - return stack(arrays, s); + return mx::stack(arrays, s); } }, nb::arg(), @@ -2802,9 +2852,10 @@ void init_ops(nb::module_& m) { [](nb::args arrays_, bool sparse, std::string indexing, - StreamOrDevice s) { - std::vector arrays = nb::cast>(arrays_); - return meshgrid(arrays, sparse, indexing, s); + mx::StreamOrDevice s) { + std::vector arrays = + nb::cast>(arrays_); + return mx::meshgrid(arrays, sparse, indexing, s); }, "arrays"_a, "sparse"_a = false, @@ -2828,14 +2879,14 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "repeat", - [](const array& array, + [](const mx::array& array, int repeats, std::optional axis, - StreamOrDevice s) { + mx::StreamOrDevice s) { if (axis.has_value()) { - return repeat(array, repeats, axis.value(), s); + return mx::repeat(array, repeats, axis.value(), s); } else { - return repeat(array, repeats, s); + return mx::repeat(array, repeats, s); } }, nb::arg(), @@ -2861,19 +2912,19 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "clip", - [](const array& a, + [](const mx::array& a, const std::optional& min, const std::optional& max, - StreamOrDevice s) { - std::optional min_ = std::nullopt; - std::optional max_ = std::nullopt; + mx::StreamOrDevice s) { + std::optional min_ = std::nullopt; + std::optional max_ = std::nullopt; if (min) { min_ = to_arrays(a, min.value()).second; } if (max) { max_ = to_arrays(a, max.value()).second; } - return clip(a, min_, max_, s); + return mx::clip(a, min_, max_, s); }, nb::arg(), "a_min"_a.none(), @@ -2899,7 +2950,7 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "pad", - [](const array& a, + [](const mx::array& a, const std::variant< int, std::tuple, @@ -2907,19 +2958,20 @@ void init_ops(nb::module_& m) { std::vector>>& pad_width, const std::string mode, const ScalarOrArray& constant_value, - StreamOrDevice s) { + mx::StreamOrDevice s) { if (auto pv = std::get_if(&pad_width); pv) { - return pad(a, *pv, to_array(constant_value), mode, s); + return mx::pad(a, *pv, to_array(constant_value), mode, s); } else if (auto pv = std::get_if>(&pad_width); pv) { - return pad(a, std::get<0>(*pv), to_array(constant_value), mode, s); + return mx::pad( + a, std::get<0>(*pv), to_array(constant_value), mode, s); } else if (auto pv = std::get_if>(&pad_width); pv) { - return pad(a, *pv, to_array(constant_value), mode, s); + return mx::pad(a, *pv, to_array(constant_value), mode, s); } else { auto v = std::get>>(pad_width); if (v.size() == 1) { - return pad(a, v[0], to_array(constant_value), mode, s); + return mx::pad(a, v[0], to_array(constant_value), mode, s); } else { - return pad(a, v, to_array(constant_value), mode, s); + return mx::pad(a, v, to_array(constant_value), mode, s); } } }, @@ -2953,22 +3005,22 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "as_strided", - [](const array& a, - std::optional shape, - std::optional strides, + [](const mx::array& a, + std::optional shape, + std::optional strides, size_t offset, - StreamOrDevice s) { + mx::StreamOrDevice s) { auto a_shape = (shape) ? *shape : a.shape(); - Strides a_strides; + mx::Strides a_strides; if (strides) { a_strides = *strides; } else { - a_strides = Strides(a_shape.size(), 1); + a_strides = mx::Strides(a_shape.size(), 1); for (int i = a_shape.size() - 1; i > 0; i--) { a_strides[i - 1] = a_shape[i] * a_strides[i]; } } - return as_strided(a, a_shape, a_strides, offset, s); + return mx::as_strided(a, a_shape, a_strides, offset, s); }, nb::arg(), "shape"_a = nb::none(), @@ -3006,15 +3058,15 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "cumsum", - [](const array& a, + [](const mx::array& a, std::optional axis, bool reverse, bool inclusive, - StreamOrDevice s) { + mx::StreamOrDevice s) { if (axis) { - return cumsum(a, *axis, reverse, inclusive, s); + return mx::cumsum(a, *axis, reverse, inclusive, s); } else { - return cumsum(reshape(a, {-1}, s), 0, reverse, inclusive, s); + return mx::cumsum(mx::reshape(a, {-1}, s), 0, reverse, inclusive, s); } }, nb::arg(), @@ -3042,15 +3094,15 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "cumprod", - [](const array& a, + [](const mx::array& a, std::optional axis, bool reverse, bool inclusive, - StreamOrDevice s) { + mx::StreamOrDevice s) { if (axis) { - return cumprod(a, *axis, reverse, inclusive, s); + return mx::cumprod(a, *axis, reverse, inclusive, s); } else { - return cumprod(reshape(a, {-1}, s), 0, reverse, inclusive, s); + return mx::cumprod(mx::reshape(a, {-1}, s), 0, reverse, inclusive, s); } }, nb::arg(), @@ -3078,15 +3130,15 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "cummax", - [](const array& a, + [](const mx::array& a, std::optional axis, bool reverse, bool inclusive, - StreamOrDevice s) { + mx::StreamOrDevice s) { if (axis) { - return cummax(a, *axis, reverse, inclusive, s); + return mx::cummax(a, *axis, reverse, inclusive, s); } else { - return cummax(reshape(a, {-1}, s), 0, reverse, inclusive, s); + return mx::cummax(mx::reshape(a, {-1}, s), 0, reverse, inclusive, s); } }, nb::arg(), @@ -3114,15 +3166,15 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "cummin", - [](const array& a, + [](const mx::array& a, std::optional axis, bool reverse, bool inclusive, - StreamOrDevice s) { + mx::StreamOrDevice s) { if (axis) { - return cummin(a, *axis, reverse, inclusive, s); + return mx::cummin(a, *axis, reverse, inclusive, s); } else { - return cummin(reshape(a, {-1}, s), 0, reverse, inclusive, s); + return mx::cummin(mx::reshape(a, {-1}, s), 0, reverse, inclusive, s); } }, nb::arg(), @@ -3150,8 +3202,8 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "conj", - [](const ScalarOrArray& a, StreamOrDevice s) { - return mlx::core::conjugate(to_array(a), s); + [](const ScalarOrArray& a, mx::StreamOrDevice s) { + return mx::conjugate(to_array(a), s); }, nb::arg(), nb::kw_only(), @@ -3170,8 +3222,8 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "conjugate", - [](const ScalarOrArray& a, StreamOrDevice s) { - return mlx::core::conjugate(to_array(a), s); + [](const ScalarOrArray& a, mx::StreamOrDevice s) { + return mx::conjugate(to_array(a), s); }, nb::arg(), nb::kw_only(), @@ -3190,10 +3242,10 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "convolve", - [](const array& a, - const array& v, + [](const mx::array& a, + const mx::array& v, const std::string& mode, - StreamOrDevice s) { + mx::StreamOrDevice s) { if (a.ndim() != 1 || v.ndim() != 1) { throw std::invalid_argument("[convolve] Inputs must be 1D."); } @@ -3202,12 +3254,12 @@ void init_ops(nb::module_& m) { throw std::invalid_argument("[convolve] Inputs cannot be empty."); } - array in = a.size() < v.size() ? v : a; - array wt = a.size() < v.size() ? a : v; - wt = slice(wt, {wt.shape(0) - 1}, {-wt.shape(0) - 1}, {-1}, s); + mx::array in = a.size() < v.size() ? v : a; + mx::array wt = a.size() < v.size() ? a : v; + wt = mx::slice(wt, {wt.shape(0) - 1}, {-wt.shape(0) - 1}, {-1}, s); - in = reshape(in, {1, -1, 1}, s); - wt = reshape(wt, {1, -1, 1}, s); + in = mx::reshape(in, {1, -1, 1}, s); + wt = mx::reshape(wt, {1, -1, 1}, s); int padding = 0; @@ -3222,15 +3274,19 @@ void init_ops(nb::module_& m) { } else { // Even sizes use asymmetric padding int pad_l = wt.size() / 2; int pad_r = std::max(0, pad_l - 1); - in = pad( - in, {{0, 0}, {pad_l, pad_r}, {0, 0}}, array(0), "constant", s); + in = mx::pad( + in, + {{0, 0}, {pad_l, pad_r}, {0, 0}}, + mx::array(0), + "constant", + s); } } else { throw std::invalid_argument("[convolve] Invalid mode."); } - array out = conv1d( + mx::array out = mx::conv1d( in, wt, /*stride = */ 1, @@ -3239,7 +3295,7 @@ void init_ops(nb::module_& m) { /*groups = */ 1, s); - return reshape(out, {-1}, s); + return mx::reshape(out, {-1}, s); }, nb::arg(), nb::arg(), @@ -3264,7 +3320,7 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "conv1d", - &conv1d, + &mx::conv1d, nb::arg(), nb::arg(), "stride"_a = 1, @@ -3291,13 +3347,13 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "conv2d", - [](const array& input, - const array& weight, + [](const mx::array& input, + const mx::array& weight, const std::variant>& stride, const std::variant>& padding, const std::variant>& dilation, int groups, - StreamOrDevice s) { + mx::StreamOrDevice s) { std::pair stride_pair{1, 1}; std::pair padding_pair{0, 0}; std::pair dilation_pair{1, 1}; @@ -3320,7 +3376,7 @@ void init_ops(nb::module_& m) { dilation_pair = std::get>(dilation); } - return conv2d( + return mx::conv2d( input, weight, stride_pair, padding_pair, dilation_pair, groups, s); }, nb::arg(), @@ -3355,13 +3411,13 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "conv3d", - [](const array& input, - const array& weight, + [](const mx::array& input, + const mx::array& weight, const std::variant>& stride, const std::variant>& padding, const std::variant>& dilation, int groups, - StreamOrDevice s) { + mx::StreamOrDevice s) { std::tuple stride_tuple{1, 1, 1}; std::tuple padding_tuple{0, 0, 0}; std::tuple dilation_tuple{1, 1, 1}; @@ -3384,7 +3440,7 @@ void init_ops(nb::module_& m) { dilation_tuple = std::get>(dilation); } - return conv3d( + return mx::conv3d( input, weight, stride_tuple, @@ -3427,7 +3483,7 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "conv_transpose1d", - &conv_transpose1d, + &mx::conv_transpose1d, nb::arg(), nb::arg(), "stride"_a = 1, @@ -3454,13 +3510,13 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "conv_transpose2d", - [](const array& input, - const array& weight, + [](const mx::array& input, + const mx::array& weight, const std::variant>& stride, const std::variant>& padding, const std::variant>& dilation, int groups, - StreamOrDevice s) { + mx::StreamOrDevice s) { std::pair stride_pair{1, 1}; std::pair padding_pair{0, 0}; std::pair dilation_pair{1, 1}; @@ -3483,7 +3539,7 @@ void init_ops(nb::module_& m) { dilation_pair = std::get>(dilation); } - return conv_transpose2d( + return mx::conv_transpose2d( input, weight, stride_pair, padding_pair, dilation_pair, groups, s); }, nb::arg(), @@ -3520,13 +3576,13 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "conv_transpose3d", - [](const array& input, - const array& weight, + [](const mx::array& input, + const mx::array& weight, const std::variant>& stride, const std::variant>& padding, const std::variant>& dilation, int groups, - StreamOrDevice s) { + mx::StreamOrDevice s) { std::tuple stride_tuple{1, 1, 1}; std::tuple padding_tuple{0, 0, 0}; std::tuple dilation_tuple{1, 1, 1}; @@ -3549,7 +3605,7 @@ void init_ops(nb::module_& m) { dilation_tuple = std::get>(dilation); } - return conv_transpose3d( + return mx::conv_transpose3d( input, weight, stride_tuple, @@ -3592,8 +3648,8 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "conv_general", - [](const array& input, - const array& weight, + [](const mx::array& input, + const mx::array& weight, const std::variant>& stride, const std::variant< int, @@ -3603,7 +3659,7 @@ void init_ops(nb::module_& m) { const std::variant>& input_dilation, int groups, bool flip, - StreamOrDevice s) { + mx::StreamOrDevice s) { std::vector stride_vec; std::vector padding_lo_vec; std::vector padding_hi_vec; @@ -3641,7 +3697,7 @@ void init_ops(nb::module_& m) { input_dilation_vec = std::get>(input_dilation); } - return conv_general( + return mx::conv_general( /* array input = */ std::move(input), /* array weight = */ std::move(weight), /* std::vector stride = */ std::move(stride_vec), @@ -3842,9 +3898,9 @@ void init_ops(nb::module_& m) { [](const ScalarOrArray& condition, const ScalarOrArray& x_, const ScalarOrArray& y_, - StreamOrDevice s) { + mx::StreamOrDevice s) { auto [x, y] = to_arrays(x_, y_); - return where(to_array(condition), x, y, s); + return mx::where(to_array(condition), x, y, s); }, "condition"_a, nb::arg(), @@ -3874,8 +3930,8 @@ void init_ops(nb::module_& m) { float nan, std::optional& posinf, std::optional& neginf, - StreamOrDevice s) { - return nan_to_num(to_array(a), nan, posinf, neginf, s); + mx::StreamOrDevice s) { + return mx::nan_to_num(to_array(a), nan, posinf, neginf, s); }, nb::arg(), "nan"_a = 0.0f, @@ -3903,8 +3959,8 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "round", - [](const ScalarOrArray& a, int decimals, StreamOrDevice s) { - return round(to_array(a), decimals, s); + [](const ScalarOrArray& a, int decimals, mx::StreamOrDevice s) { + return mx::round(to_array(a), decimals, s); }, nb::arg(), "decimals"_a = 0, @@ -3932,7 +3988,7 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "quantized_matmul", - &quantized_matmul, + &mx::quantized_matmul, nb::arg(), nb::arg(), "scales"_a, @@ -3968,7 +4024,7 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "quantize", - &quantize, + &mx::quantize, nb::arg(), "group_size"_a = 64, "bits"_a = 4, @@ -4027,7 +4083,7 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "dequantize", - &dequantize, + &mx::dequantize, nb::arg(), "scales"_a, "biases"_a, @@ -4063,7 +4119,7 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "gather_qmm", - &gather_qmm, + &mx::gather_qmm, nb::arg(), nb::arg(), "scales"_a, @@ -4109,19 +4165,19 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "tensordot", - [](const array& a, - const array& b, + [](const mx::array& a, + const mx::array& b, const std::variant>>& axes, - StreamOrDevice s) { + mx::StreamOrDevice s) { if (auto pv = std::get_if(&axes); pv) { - return tensordot(a, b, *pv, s); + return mx::tensordot(a, b, *pv, s); } else { auto& x = std::get>>(axes); if (x.size() != 2) { throw std::invalid_argument( "[tensordot] axes must be a list of two lists."); } - return tensordot(a, b, x[0], x[1], s); + return mx::tensordot(a, b, x[0], x[1], s); } }, nb::arg(), @@ -4148,7 +4204,7 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "inner", - &inner, + &mx::inner, nb::arg(), nb::arg(), nb::kw_only(), @@ -4167,7 +4223,7 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "outer", - &outer, + &mx::outer, nb::arg(), nb::arg(), nb::kw_only(), @@ -4186,13 +4242,13 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "tile", - [](const array& a, + [](const mx::array& a, const std::variant>& reps, - StreamOrDevice s) { + mx::StreamOrDevice s) { if (auto pv = std::get_if(&reps); pv) { - return tile(a, {*pv}, s); + return mx::tile(a, {*pv}, s); } else { - return tile(a, std::get>(reps), s); + return mx::tile(a, std::get>(reps), s); } }, nb::arg(), @@ -4213,7 +4269,7 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "addmm", - &addmm, + &mx::addmm, nb::arg(), nb::arg(), nb::arg(), @@ -4242,7 +4298,7 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "block_masked_mm", - &block_masked_mm, + &mx::block_masked_mm, nb::arg(), nb::arg(), "block_size"_a = 64, @@ -4282,7 +4338,7 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "gather_mm", - &gather_mm, + &mx::gather_mm, nb::arg(), nb::arg(), "lhs_indices"_a = nb::none(), @@ -4320,7 +4376,7 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "diagonal", - &diagonal, + &mx::diagonal, "a"_a, "offset"_a = 0, "axis1"_a = 0, @@ -4353,7 +4409,7 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "diag", - &diag, + &mx::diag, nb::arg(), "k"_a = 0, nb::kw_only(), @@ -4376,16 +4432,16 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "trace", - [](const array& a, + [](const mx::array& a, int offset, int axis1, int axis2, - std::optional dtype, - StreamOrDevice s) { + std::optional dtype, + mx::StreamOrDevice s) { if (!dtype.has_value()) { - return trace(a, offset, axis1, axis2, s); + return mx::trace(a, offset, axis1, axis2, s); } - return trace(a, offset, axis1, axis2, dtype.value(), s); + return mx::trace(a, offset, axis1, axis2, dtype.value(), s); }, nb::arg(), "offset"_a = 0, @@ -4415,11 +4471,12 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "atleast_1d", - [](const nb::args& arys, StreamOrDevice s) -> nb::object { + [](const nb::args& arys, mx::StreamOrDevice s) -> nb::object { if (arys.size() == 1) { - return nb::cast(atleast_1d(nb::cast(arys[0]), s)); + return nb::cast(mx::atleast_1d(nb::cast(arys[0]), s)); } - return nb::cast(atleast_1d(nb::cast>(arys), s)); + return nb::cast( + mx::atleast_1d(nb::cast>(arys), s)); }, "arys"_a, "stream"_a = nb::none(), @@ -4437,11 +4494,12 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "atleast_2d", - [](const nb::args& arys, StreamOrDevice s) -> nb::object { + [](const nb::args& arys, mx::StreamOrDevice s) -> nb::object { if (arys.size() == 1) { - return nb::cast(atleast_2d(nb::cast(arys[0]), s)); + return nb::cast(mx::atleast_2d(nb::cast(arys[0]), s)); } - return nb::cast(atleast_2d(nb::cast>(arys), s)); + return nb::cast( + mx::atleast_2d(nb::cast>(arys), s)); }, "arys"_a, "stream"_a = nb::none(), @@ -4459,11 +4517,12 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "atleast_3d", - [](const nb::args& arys, StreamOrDevice s) -> nb::object { + [](const nb::args& arys, mx::StreamOrDevice s) -> nb::object { if (arys.size() == 1) { - return nb::cast(atleast_3d(nb::cast(arys[0]), s)); + return nb::cast(mx::atleast_3d(nb::cast(arys[0]), s)); } - return nb::cast(atleast_3d(nb::cast>(arys), s)); + return nb::cast( + mx::atleast_3d(nb::cast>(arys), s)); }, "arys"_a, "stream"_a = nb::none(), @@ -4483,19 +4542,19 @@ void init_ops(nb::module_& m) { "issubdtype", [](const nb::object& d1, const nb::object& d2) { auto dispatch_second = [](const auto& t1, const auto& d2) { - if (nb::isinstance(d2)) { - return issubdtype(t1, nb::cast(d2)); - } else if (nb::isinstance(d2)) { - return issubdtype(t1, nb::cast(d2)); + if (nb::isinstance(d2)) { + return mx::issubdtype(t1, nb::cast(d2)); + } else if (nb::isinstance(d2)) { + return mx::issubdtype(t1, nb::cast(d2)); } else { throw std::invalid_argument( "[issubdtype] Received invalid type for second input."); } }; - if (nb::isinstance(d1)) { - return dispatch_second(nb::cast(d1), d2); - } else if (nb::isinstance(d1)) { - return dispatch_second(nb::cast(d1), d2); + if (nb::isinstance(d1)) { + return dispatch_second(nb::cast(d1), d2); + } else if (nb::isinstance(d1)) { + return dispatch_second(nb::cast(d1), d2); } else { throw std::invalid_argument( "[issubdtype] Received invalid type for first input."); @@ -4555,9 +4614,11 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "bitwise_and", - [](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) { + [](const ScalarOrArray& a_, + const ScalarOrArray& b_, + mx::StreamOrDevice s) { auto [a, b] = to_arrays(a_, b_); - return bitwise_and(a, b, s); + return mx::bitwise_and(a, b, s); }, nb::arg(), nb::arg(), @@ -4580,9 +4641,11 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "bitwise_or", - [](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) { + [](const ScalarOrArray& a_, + const ScalarOrArray& b_, + mx::StreamOrDevice s) { auto [a, b] = to_arrays(a_, b_); - return bitwise_or(a, b, s); + return mx::bitwise_or(a, b, s); }, nb::arg(), nb::arg(), @@ -4605,9 +4668,11 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "bitwise_xor", - [](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) { + [](const ScalarOrArray& a_, + const ScalarOrArray& b_, + mx::StreamOrDevice s) { auto [a, b] = to_arrays(a_, b_); - return bitwise_xor(a, b, s); + return mx::bitwise_xor(a, b, s); }, nb::arg(), nb::arg(), @@ -4631,9 +4696,11 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "left_shift", - [](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) { + [](const ScalarOrArray& a_, + const ScalarOrArray& b_, + mx::StreamOrDevice s) { auto [a, b] = to_arrays(a_, b_); - return left_shift(a, b, s); + return mx::left_shift(a, b, s); }, nb::arg(), nb::arg(), @@ -4657,9 +4724,11 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "right_shift", - [](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) { + [](const ScalarOrArray& a_, + const ScalarOrArray& b_, + mx::StreamOrDevice s) { auto [a, b] = to_arrays(a_, b_); - return right_shift(a, b, s); + return mx::right_shift(a, b, s); }, nb::arg(), nb::arg(), @@ -4683,8 +4752,8 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "view", - [](const ScalarOrArray& a, const Dtype& dtype, StreamOrDevice s) { - return view(to_array(a), dtype, s); + [](const ScalarOrArray& a, const mx::Dtype& dtype, mx::StreamOrDevice s) { + return mx::view(to_array(a), dtype, s); }, nb::arg(), "dtype"_a, @@ -4711,7 +4780,7 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "hadamard_transform", - &hadamard_transform, + &mx::hadamard_transform, nb::arg(), "scale"_a = nb::none(), nb::kw_only(), @@ -4743,8 +4812,8 @@ void init_ops(nb::module_& m) { m.def( "einsum_path", [](const std::string& equation, const nb::args& operands) { - auto arrays_list = nb::cast>(operands); - auto [path, str] = einsum_path(equation, arrays_list); + auto arrays_list = nb::cast>(operands); + auto [path, str] = mx::einsum_path(equation, arrays_list); // Convert to list of tuples std::vector tuple_path; for (auto& p : path) { @@ -4772,9 +4841,9 @@ void init_ops(nb::module_& m) { "einsum", [](const std::string& subscripts, const nb::args& operands, - StreamOrDevice s) { - auto arrays_list = nb::cast>(operands); - return einsum(subscripts, arrays_list, s); + mx::StreamOrDevice s) { + auto arrays_list = nb::cast>(operands); + return mx::einsum(subscripts, arrays_list, s); }, "subscripts"_a, "operands"_a, @@ -4795,12 +4864,12 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "roll", - [](const array& a, + [](const mx::array& a, const IntOrVec& shift, const IntOrVec& axis, - StreamOrDevice s) { + mx::StreamOrDevice s) { return std::visit( - [&](auto sh, auto ax) -> array { + [&](auto sh, auto ax) -> mx::array { using T = decltype(ax); using V = decltype(sh); @@ -4809,9 +4878,9 @@ void init_ops(nb::module_& m) { "[roll] Expected two arguments but only one was given."); } else { if constexpr (std::is_same_v) { - return roll(a, sh, s); + return mx::roll(a, sh, s); } else { - return roll(a, sh, ax, s); + return mx::roll(a, sh, ax, s); } } }, @@ -4845,8 +4914,8 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "real", - [](const ScalarOrArray& a, StreamOrDevice s) { - return mlx::core::real(to_array(a), s); + [](const ScalarOrArray& a, mx::StreamOrDevice s) { + return mx::real(to_array(a), s); }, nb::arg(), nb::kw_only(), @@ -4864,8 +4933,8 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "imag", - [](const ScalarOrArray& a, StreamOrDevice s) { - return mlx::core::imag(to_array(a), s); + [](const ScalarOrArray& a, mx::StreamOrDevice s) { + return mx::imag(to_array(a), s); }, nb::arg(), nb::kw_only(), diff --git a/python/src/random.cpp b/python/src/random.cpp index 7b3200764..b67dfc219 100644 --- a/python/src/random.cpp +++ b/python/src/random.cpp @@ -12,23 +12,22 @@ #include "mlx/ops.h" #include "mlx/random.h" +namespace mx = mlx::core; namespace nb = nanobind; using namespace nb::literals; -using namespace mlx::core; -using namespace mlx::core::random; class PyKeySequence { public: explicit PyKeySequence(uint64_t seed) { - state_.append(key(seed)); + state_.append(mx::random::key(seed)); } void seed(uint64_t seed) { - state_[0] = key(seed); + state_[0] = mx::random::key(seed); } - array next() { - auto out = split(nb::cast(state_[0])); + mx::array next() { + auto out = mx::random::split(nb::cast(state_[0])); state_[0] = out.first; return out.second; } @@ -75,7 +74,7 @@ void init_random(nb::module_& parent_module) { )pbdoc"); m.def( "key", - &key, + &mx::random::key, "seed"_a, R"pbdoc( Get a PRNG key from a seed. @@ -88,7 +87,8 @@ void init_random(nb::module_& parent_module) { )pbdoc"); m.def( "split", - nb::overload_cast(&random::split), + nb::overload_cast( + &mx::random::split), "key"_a, "num"_a = 2, "stream"_a = nb::none(), @@ -109,22 +109,22 @@ void init_random(nb::module_& parent_module) { [](const ScalarOrArray& low, const ScalarOrArray& high, const std::vector& shape, - std::optional type, - const std::optional& key_, - StreamOrDevice s) { + std::optional type, + const std::optional& key_, + mx::StreamOrDevice s) { auto key = key_ ? key_.value() : default_key().next(); - return uniform( + return mx::random::uniform( to_array(low), to_array(high), shape, - type.value_or(float32), + type.value_or(mx::float32), key, s); }, "low"_a = 0, "high"_a = 1, "shape"_a = std::vector{}, - "dtype"_a.none() = float32, + "dtype"_a.none() = mx::float32, "key"_a = nb::none(), "stream"_a = nb::none(), nb::sig( @@ -151,16 +151,17 @@ void init_random(nb::module_& parent_module) { m.def( "normal", [](const std::vector& shape, - std::optional type, + std::optional type, float loc, float scale, - const std::optional& key_, - StreamOrDevice s) { + const std::optional& key_, + mx::StreamOrDevice s) { auto key = key_ ? key_.value() : default_key().next(); - return normal(shape, type.value_or(float32), loc, scale, key, s); + return mx::random::normal( + shape, type.value_or(mx::float32), loc, scale, key, s); }, "shape"_a = std::vector{}, - "dtype"_a.none() = float32, + "dtype"_a.none() = mx::float32, "loc"_a = 0.0, "scale"_a = 1.0, "key"_a = nb::none(), @@ -182,20 +183,20 @@ void init_random(nb::module_& parent_module) { )pbdoc"); m.def( "multivariate_normal", - [](const array& mean, - const array& cov, + [](const mx::array& mean, + const mx::array& cov, const std::vector& shape, - std::optional type, - const std::optional& key_, - StreamOrDevice s) { + std::optional type, + const std::optional& key_, + mx::StreamOrDevice s) { auto key = key_ ? key_.value() : default_key().next(); - return multivariate_normal( - mean, cov, shape, type.value_or(float32), key, s); + return mx::random::multivariate_normal( + mean, cov, shape, type.value_or(mx::float32), key, s); }, "mean"_a, "cov"_a, "shape"_a = std::vector{}, - "dtype"_a.none() = float32, + "dtype"_a.none() = mx::float32, "key"_a = nb::none(), "stream"_a = nb::none(), nb::sig( @@ -227,17 +228,22 @@ void init_random(nb::module_& parent_module) { [](const ScalarOrArray& low, const ScalarOrArray& high, const std::vector& shape, - std::optional type, - const std::optional& key_, - StreamOrDevice s) { + std::optional type, + const std::optional& key_, + mx::StreamOrDevice s) { auto key = key_ ? key_.value() : default_key().next(); - return randint( - to_array(low), to_array(high), shape, type.value_or(int32), key, s); + return mx::random::randint( + to_array(low), + to_array(high), + shape, + type.value_or(mx::int32), + key, + s); }, "low"_a, "high"_a, "shape"_a = std::vector{}, - "dtype"_a.none() = int32, + "dtype"_a.none() = mx::int32, "key"_a = nb::none(), "stream"_a = nb::none(), nb::sig( @@ -263,14 +269,14 @@ void init_random(nb::module_& parent_module) { "bernoulli", [](const ScalarOrArray& p_, const std::optional> shape, - const std::optional& key_, - StreamOrDevice s) { + const std::optional& key_, + mx::StreamOrDevice s) { auto key = key_ ? key_.value() : default_key().next(); auto p = to_array(p_); if (shape.has_value()) { - return bernoulli(p, shape.value(), key, s); + return mx::random::bernoulli(p, shape.value(), key, s); } else { - return bernoulli(p, key, s); + return mx::random::bernoulli(p, key, s); } }, "p"_a = 0.5, @@ -301,23 +307,24 @@ void init_random(nb::module_& parent_module) { [](const ScalarOrArray& lower_, const ScalarOrArray& upper_, const std::optional> shape_, - std::optional type, - const std::optional& key_, - StreamOrDevice s) { + std::optional type, + const std::optional& key_, + mx::StreamOrDevice s) { auto key = key_ ? key_.value() : default_key().next(); auto lower = to_array(lower_); auto upper = to_array(upper_); - auto t = type.value_or(float32); + auto t = type.value_or(mx::float32); if (shape_.has_value()) { - return truncated_normal(lower, upper, shape_.value(), t, key, s); + return mx::random::truncated_normal( + lower, upper, shape_.value(), t, key, s); } else { - return truncated_normal(lower, upper, t, key, s); + return mx::random::truncated_normal(lower, upper, t, key, s); } }, "lower"_a, "upper"_a, "shape"_a = nb::none(), - "dtype"_a.none() = float32, + "dtype"_a.none() = mx::float32, "key"_a = nb::none(), "stream"_a = nb::none(), nb::sig( @@ -344,14 +351,14 @@ void init_random(nb::module_& parent_module) { m.def( "gumbel", [](const std::vector& shape, - std::optional type, - const std::optional& key_, - StreamOrDevice s) { + std::optional type, + const std::optional& key_, + mx::StreamOrDevice s) { auto key = key_ ? key_.value() : default_key().next(); - return gumbel(shape, type.value_or(float32), key, s); + return mx::random::gumbel(shape, type.value_or(mx::float32), key, s); }, "shape"_a = std::vector{}, - "dtype"_a.none() = float32, + "dtype"_a.none() = mx::float32, "key"_a = nb::none(), "stream"_a = nb::none(), nb::sig( @@ -375,22 +382,23 @@ void init_random(nb::module_& parent_module) { )pbdoc"); m.def( "categorical", - [](const array& logits, + [](const mx::array& logits, int axis, const std::optional> shape, const std::optional num_samples, - const std::optional& key_, - StreamOrDevice s) { + const std::optional& key_, + mx::StreamOrDevice s) { auto key = key_ ? key_.value() : default_key().next(); if (shape.has_value() && num_samples.has_value()) { throw std::invalid_argument( "[categorical] At most one of shape or num_samples can be specified."); } else if (shape.has_value()) { - return categorical(logits, axis, shape.value(), key, s); + return mx::random::categorical(logits, axis, shape.value(), key, s); } else if (num_samples.has_value()) { - return categorical(logits, axis, num_samples.value(), key, s); + return mx::random::categorical( + logits, axis, num_samples.value(), key, s); } else { - return categorical(logits, axis, key, s); + return mx::random::categorical(logits, axis, key, s); } }, "logits"_a, @@ -427,16 +435,17 @@ void init_random(nb::module_& parent_module) { m.def( "laplace", [](const std::vector& shape, - std::optional type, + std::optional type, float loc, float scale, - const std::optional& key_, - StreamOrDevice s) { + const std::optional& key_, + mx::StreamOrDevice s) { auto key = key_ ? key_.value() : default_key().next(); - return laplace(shape, type.value_or(float32), loc, scale, key, s); + return mx::random::laplace( + shape, type.value_or(mx::float32), loc, scale, key, s); }, "shape"_a = std::vector{}, - "dtype"_a.none() = float32, + "dtype"_a.none() = mx::float32, "loc"_a = 0.0, "scale"_a = 1.0, "key"_a = nb::none(), @@ -459,15 +468,15 @@ void init_random(nb::module_& parent_module) { )pbdoc"); m.def( "permuation", - [](const std::variant& x, + [](const std::variant& x, int axis, - const std::optional& key_, - StreamOrDevice s) { + const std::optional& key_, + mx::StreamOrDevice s) { auto key = key_ ? key_.value() : default_key().next(); if (auto pv = std::get_if(&x); pv) { - return permutation(nb::cast(*pv), key, s); + return mx::random::permutation(nb::cast(*pv), key, s); } else { - return permutation(std::get(x), axis, key, s); + return mx::random::permutation(std::get(x), axis, key, s); } }, "shape"_a = std::vector{}, diff --git a/python/src/stream.cpp b/python/src/stream.cpp index e1eb6b953..e10f4751c 100644 --- a/python/src/stream.cpp +++ b/python/src/stream.cpp @@ -10,14 +10,14 @@ #include "mlx/stream.h" #include "mlx/utils.h" +namespace mx = mlx::core; namespace nb = nanobind; using namespace nb::literals; -using namespace mlx::core; // Create the StreamContext on enter and delete on exit. class PyStreamContext { public: - PyStreamContext(StreamOrDevice s) : _inner(nullptr) { + PyStreamContext(mx::StreamOrDevice s) : _inner(nullptr) { if (std::holds_alternative(s)) { throw std::runtime_error( "[StreamContext] Invalid argument, please specify a stream or device."); @@ -26,7 +26,7 @@ class PyStreamContext { } void enter() { - _inner = new StreamContext(_s); + _inner = new mx::StreamContext(_s); } void exit() { @@ -37,39 +37,40 @@ class PyStreamContext { } private: - StreamOrDevice _s; - StreamContext* _inner; + mx::StreamOrDevice _s; + mx::StreamContext* _inner; }; void init_stream(nb::module_& m) { - nb::class_( + nb::class_( m, "Stream", R"pbdoc( A stream for running operations on a given device. )pbdoc") - .def_ro("device", &Stream::device) + .def_ro("device", &mx::Stream::device) .def( "__repr__", - [](const Stream& s) { + [](const mx::Stream& s) { std::ostringstream os; os << s; return os.str(); }) - .def("__eq__", [](const Stream& s, const nb::object& other) { - return nb::isinstance(other) && s == nb::cast(other); + .def("__eq__", [](const mx::Stream& s, const nb::object& other) { + return nb::isinstance(other) && + s == nb::cast(other); }); - nb::implicitly_convertible(); + nb::implicitly_convertible(); m.def( "default_stream", - &default_stream, + &mx::default_stream, "device"_a, R"pbdoc(Get the device's default stream.)pbdoc"); m.def( "set_default_stream", - &set_default_stream, + &mx::set_default_stream, "stream"_a, R"pbdoc( Set the default stream. @@ -82,7 +83,7 @@ void init_stream(nb::module_& m) { )pbdoc"); m.def( "new_stream", - &new_stream, + &mx::new_stream, "device"_a, R"pbdoc(Make a new stream on the given device.)pbdoc"); @@ -94,7 +95,7 @@ void init_stream(nb::module_& m) { Args: s: The stream or device to set as the default. )pbdoc") - .def(nb::init(), "s"_a) + .def(nb::init(), "s"_a) .def("__enter__", [](PyStreamContext& scm) { scm.enter(); }) .def( "__exit__", @@ -107,7 +108,7 @@ void init_stream(nb::module_& m) { "traceback"_a = nb::none()); m.def( "stream", - [](StreamOrDevice s) { return PyStreamContext(s); }, + [](mx::StreamOrDevice s) { return PyStreamContext(s); }, "s"_a, R"pbdoc( Create a context manager to set the default device and stream. @@ -131,8 +132,8 @@ void init_stream(nb::module_& m) { )pbdoc"); m.def( "synchronize", - [](const std::optional& s) { - s ? synchronize(s.value()) : synchronize(); + [](const std::optional& s) { + s ? mx::synchronize(s.value()) : mx::synchronize(); }, "stream"_a = nb::none(), R"pbdoc( diff --git a/python/src/transforms.cpp b/python/src/transforms.cpp index 58425d949..c7fd5b4c7 100644 --- a/python/src/transforms.cpp +++ b/python/src/transforms.cpp @@ -20,9 +20,12 @@ #include "mlx/utils.h" #include "python/src/trees.h" +namespace mx = mlx::core; namespace nb = nanobind; using namespace nb::literals; -using namespace mlx::core; + +// Needed for printing shapes and strides. +using mx::operator<<; using IntOrVec = std::variant>; using StrOrVec = std::variant>; @@ -108,7 +111,7 @@ auto py_value_and_grad( } // Collect the arrays - std::vector arrays; + std::vector arrays; std::vector counts(1, 0); for (auto i : argnums) { auto argsi = tree_flatten(args[i]); @@ -127,7 +130,7 @@ auto py_value_and_grad( // value_out will hold the output of the python function in order to be // able to reconstruct the python tree of extra return values nb::object py_value_out; - auto value_and_grads = value_and_grad( + auto value_and_grads = mx::value_and_grad( [&fun, &args, &kwargs, @@ -136,7 +139,7 @@ auto py_value_and_grad( &counts, &py_value_out, &error_msg_tag, - scalar_func_only](const std::vector& a) { + scalar_func_only](const std::vector& a) { // Copy the arguments nb::list args_cpy; nb::kwargs kwargs_cpy = nb::kwargs(); @@ -165,7 +168,7 @@ auto py_value_and_grad( py_value_out = fun(*args_cpy, **kwargs_cpy); // Validate the return value of the python function - if (!nb::isinstance(py_value_out)) { + if (!nb::isinstance(py_value_out)) { if (scalar_func_only) { std::ostringstream msg; msg << error_msg_tag << " The return value of the function " @@ -193,7 +196,7 @@ auto py_value_and_grad( << "we got an empty tuple."; throw std::invalid_argument(msg.str()); } - if (!nb::isinstance(ret[0])) { + if (!nb::isinstance(ret[0])) { std::ostringstream msg; msg << error_msg_tag << " The return value of the function " << "whose gradient we want to compute should be either a " @@ -275,12 +278,12 @@ auto py_vmap( {tree, axes}, [&flat_axes, &encountered_tuple, output_axes]( const std::vector& inputs) { - if (nb::isinstance(inputs[0])) { + if (nb::isinstance(inputs[0])) { if (inputs[1].is_none()) { flat_axes.push_back(-1); } else if (nb::isinstance(inputs[1])) { int axis = nb::cast(nb::cast(inputs[1])); - const array& x = nb::cast(inputs[0]); + const mx::array& x = nb::cast(inputs[0]); if (axis < 0) { axis += x.ndim() + output_axes; } @@ -297,7 +300,7 @@ auto py_vmap( auto l = nb::cast(inputs[1]); if (l.size() == 1 && nb::isinstance(l[0])) { int axis = nb::cast(nb::cast(l[0])); - const array& x = nb::cast(inputs[0]); + const mx::array& x = nb::cast(inputs[0]); if (axis < 0) { axis += x.ndim() + output_axes; } @@ -323,7 +326,7 @@ auto py_vmap( "[vmap] The arguments should contain only arrays"); } }); - if (encountered_tuple && !nb::isinstance(tree)) { + if (encountered_tuple && !nb::isinstance(tree)) { throw std::invalid_argument("[vmap] axis must be int or None."); } return flat_axes; @@ -339,7 +342,7 @@ auto py_vmap( nb::object py_outputs; auto vmap_fn = - [&fun, &args, &inputs, &py_outputs](const std::vector& a) { + [&fun, &args, &inputs, &py_outputs](const std::vector& a) { // Call the python function py_outputs = fun(*tree_unflatten(args, a)); @@ -348,12 +351,12 @@ auto py_vmap( }; auto [trace_inputs, trace_outputs] = - detail::vmap_trace(vmap_fn, inputs, flat_in_axes); + mx::detail::vmap_trace(vmap_fn, inputs, flat_in_axes); auto flat_out_axes = axes_to_flat_tree(py_outputs, out_axes, true); // Perform the vmap - auto outputs = detail::vmap_replace( + auto outputs = mx::detail::vmap_replace( inputs, trace_inputs, trace_outputs, flat_in_axes, flat_out_axes); // Put the outputs back in the container @@ -401,7 +404,7 @@ struct PyCompiledFun { nb::object call_impl(const nb::args& args, const nb::kwargs& kwargs) { // Flat array inputs - std::vector inputs; + std::vector inputs; // Compilation constants which includes the tree structure of the arguments std::vector constants; @@ -437,8 +440,8 @@ struct PyCompiledFun { constants.push_back(nb::cast(r)); recurse(item.second); } - } else if (nb::isinstance(obj)) { - inputs.push_back(nb::cast(obj)); + } else if (nb::isinstance(obj)) { + inputs.push_back(nb::cast(obj)); constants.push_back(array_identifier); } else if (nb::isinstance(obj)) { auto r = obj.attr("__hash__")(); @@ -461,10 +464,10 @@ struct PyCompiledFun { int num_args = inputs.size(); recurse(kwargs); auto compile_fun = [this, &args, &kwargs, num_args]( - const std::vector& a) { + const std::vector& a) { // Put tracers into captured inputs - std::vector flat_in_captures; - std::vector trace_captures; + std::vector flat_in_captures; + std::vector trace_captures; if (!captured_inputs.is_none()) { flat_in_captures = tree_flatten(captured_inputs, false); trace_captures.insert( @@ -505,9 +508,9 @@ struct PyCompiledFun { // Compile and call auto outputs = - detail::compile(compile_fun, fun_id, shapeless, constants)(inputs); + mx::detail::compile(compile_fun, fun_id, shapeless, constants)(inputs); if (!captured_outputs.is_none()) { - std::vector captures( + std::vector captures( std::make_move_iterator(outputs.begin() + num_outputs), std::make_move_iterator(outputs.end())); tree_fill(captured_outputs, captures); @@ -526,7 +529,7 @@ struct PyCompiledFun { nb::gil_scoped_acquire gil; tree_cache().erase(fun_id); - detail::compile_erase(fun_id); + mx::detail::compile_erase(fun_id); fun.release().dec_ref(); captured_inputs.release().dec_ref(); captured_outputs.release().dec_ref(); @@ -561,7 +564,7 @@ class PyCheckpointedFun { args_structure_.release().dec_ref(); } - std::vector operator()(const std::vector& inputs) { + std::vector operator()(const std::vector& inputs) { auto args = nb::cast( tree_unflatten_from_structure(args_structure_, inputs)); auto [outputs, output_structure] = @@ -579,7 +582,7 @@ class PyCheckpointedFun { auto [inputs, args_structure] = tree_flatten_with_structure(full_args, false); - auto outputs = checkpoint( + auto outputs = mx::checkpoint( InnerFunction(fun_, args_structure, output_structure))(inputs); return tree_unflatten_from_structure(*output_structure, outputs); @@ -660,12 +663,12 @@ class PyCustomFunction { } } - std::vector operator()(const std::vector& inputs) { + std::vector operator()(const std::vector& inputs) { nb::gil_scoped_acquire gil; auto new_inputs = nb::cast( tree_unflatten_from_structure(input_structure_, inputs)); - std::vector outputs; + std::vector outputs; std::tie(outputs, *output_structure_) = tree_flatten_with_structure(fun_(*new_inputs[0], **new_inputs[1])); return outputs; @@ -694,10 +697,10 @@ class PyCustomFunction { } } - std::vector operator()( - const std::vector& primals, - const std::vector& cotangents, - const std::vector& outputs) { + std::vector operator()( + const std::vector& primals, + const std::vector& cotangents, + const std::vector& outputs) { nb::gil_scoped_acquire gil; auto new_inputs = nb::cast( @@ -734,9 +737,9 @@ class PyCustomFunction { input_structure_.release().dec_ref(); } - std::vector operator()( - const std::vector& primals, - const std::vector& tangents, + std::vector operator()( + const std::vector& primals, + const std::vector& tangents, const std::vector& argnums) { nb::gil_scoped_acquire gil; @@ -759,7 +762,7 @@ class PyCustomFunction { int tangent_index = 0; auto new_tangents = nb::cast(tree_map(args, [&](nb::handle element) { - if (nb::isinstance(element) && + if (nb::isinstance(element) && have_tangents[array_index++]) { return nb::cast(tangents[tangent_index++]); } else { @@ -789,8 +792,8 @@ class PyCustomFunction { input_structure_.release().dec_ref(); } - std::pair, std::vector> operator()( - const std::vector& inputs, + std::pair, std::vector> operator()( + const std::vector& inputs, const std::vector& axes) { nb::gil_scoped_acquire gil; @@ -807,7 +810,7 @@ class PyCustomFunction { auto new_axes = nb::cast(tree_map(args, [&](nb::handle element) { int axis = axes[arr_index++]; - if (nb::isinstance(element) && axis >= 0) { + if (nb::isinstance(element) && axis >= 0) { return nb::cast(axis); } else { return nb::none(); @@ -831,11 +834,11 @@ class PyCustomFunction { "[custom vmap] Vmap function should return a tuple with 2 items."); } - std::vector outputs; + std::vector outputs; std::vector output_axes; tree_visit({result_tuple[0], result_tuple[1]}, [&](auto objects) { - if (nb::isinstance(objects[0])) { - outputs.push_back(nb::cast(objects[0])); + if (nb::isinstance(objects[0])) { + outputs.push_back(nb::cast(objects[0])); output_axes.push_back( objects[1].is_none() ? -1 : nb::cast(objects[1])); } @@ -852,7 +855,7 @@ class PyCustomFunction { } // Extract the inputs and their structure in capturable vars - std::vector input_arrays; + std::vector input_arrays; nb::object input_structure; auto full_args = nb::make_tuple(args, kwargs); std::tie(input_arrays, input_structure) = @@ -864,7 +867,7 @@ class PyCustomFunction { // Make a function that calls fun_ in the forward pass and vjp_ in the // backward pass. Then call it immediately and return the results. - auto f = custom_function( + auto f = mx::custom_function( InnerFunction(fun_, input_structure, output_structure), make_vjp_function(input_structure, output_structure), make_jvp_function(input_structure), @@ -1044,7 +1047,7 @@ void init_transforms(nb::module_& m) { m.def( "eval", [](const nb::args& args) { - std::vector arrays = tree_flatten(args, false); + std::vector arrays = tree_flatten(args, false); { nb::gil_scoped_release nogil; eval(arrays); @@ -1064,7 +1067,7 @@ void init_transforms(nb::module_& m) { m.def( "async_eval", [](const nb::args& args) { - std::vector arrays = tree_flatten(args, false); + std::vector arrays = tree_flatten(args, false); { nb::gil_scoped_release nogil; async_eval(arrays); @@ -1100,14 +1103,14 @@ void init_transforms(nb::module_& m) { m.def( "jvp", [](const nb::callable& fun, - const std::vector& primals, - const std::vector& tangents) { - auto vfun = [&fun](const std::vector& primals) { + const std::vector& primals, + const std::vector& tangents) { + auto vfun = [&fun](const std::vector& primals) { auto out = fun(*nb::cast(primals)); - if (nb::isinstance(out)) { - return std::vector{nb::cast(out)}; + if (nb::isinstance(out)) { + return std::vector{nb::cast(out)}; } else { - return nb::cast>(out); + return nb::cast>(out); } }; return jvp(vfun, primals, tangents); @@ -1139,14 +1142,14 @@ void init_transforms(nb::module_& m) { m.def( "vjp", [](const nb::callable& fun, - const std::vector& primals, - const std::vector& cotangents) { - auto vfun = [&fun](const std::vector& primals) { + const std::vector& primals, + const std::vector& cotangents) { + auto vfun = [&fun](const std::vector& primals) { auto out = fun(*nb::cast(primals)); - if (nb::isinstance(out)) { - return std::vector{nb::cast(out)}; + if (nb::isinstance(out)) { + return std::vector{nb::cast(out)}; } else { - return nb::cast>(out); + return nb::cast>(out); } }; return vjp(vfun, primals, cotangents); @@ -1312,7 +1315,7 @@ void init_transforms(nb::module_& m) { m.def( "export_to_dot", [](nb::object file, const nb::args& args) { - std::vector arrays = tree_flatten(args); + std::vector arrays = tree_flatten(args); if (nb::isinstance(file)) { std::ofstream out(nb::cast(file)); export_to_dot(out, arrays); @@ -1399,14 +1402,14 @@ void init_transforms(nb::module_& m) { )pbdoc"); m.def( "disable_compile", - &disable_compile, + &mx::disable_compile, R"pbdoc( Globally disable compilation. Setting the environment variable ``MLX_DISABLE_COMPILE`` can also be used to disable compilation. )pbdoc"); m.def( "enable_compile", - &enable_compile, + &mx::enable_compile, R"pbdoc( Globally enable compilation. This will override the environment variable ``MLX_DISABLE_COMPILE`` if set. @@ -1420,6 +1423,6 @@ void init_transforms(nb::module_& m) { auto atexit = nb::module_::import_("atexit"); atexit.attr("register")(nb::cpp_function([]() { tree_cache().clear(); - detail::compile_clear_cache(); + mx::detail::compile_clear_cache(); })); } diff --git a/python/src/trees.cpp b/python/src/trees.cpp index b4ae53746..d9fe6d2d3 100644 --- a/python/src/trees.cpp +++ b/python/src/trees.cpp @@ -188,7 +188,7 @@ void tree_visit_update( d[item.first] = recurse(item.second); } return nb::cast(d); - } else if (nb::isinstance(subtree)) { + } else if (nb::isinstance(subtree)) { return visitor(subtree); } else { return nb::cast(subtree); @@ -200,7 +200,7 @@ void tree_visit_update( // Fill a pytree (recursive dict or list of dict or list) // in place with the given arrays // Non dict or list nodes are ignored -void tree_fill(nb::object& tree, const std::vector& values) { +void tree_fill(nb::object& tree, const std::vector& values) { size_t index = 0; tree_visit_update( tree, [&](nb::handle node) { return nb::cast(values[index++]); }); @@ -209,14 +209,14 @@ void tree_fill(nb::object& tree, const std::vector& values) { // Replace all the arrays from the src values with the dst values in the tree void tree_replace( nb::object& tree, - const std::vector& src, - const std::vector& dst) { - std::unordered_map src_to_dst; + const std::vector& src, + const std::vector& dst) { + std::unordered_map src_to_dst; for (int i = 0; i < src.size(); ++i) { src_to_dst.insert({src[i].id(), dst[i]}); } tree_visit_update(tree, [&](nb::handle node) { - auto arr = nb::cast(node); + auto arr = nb::cast(node); if (auto it = src_to_dst.find(arr.id()); it != src_to_dst.end()) { return nb::cast(it->second); } @@ -224,12 +224,12 @@ void tree_replace( }); } -std::vector tree_flatten(nb::object tree, bool strict /* = true */) { - std::vector flat_tree; +std::vector tree_flatten(nb::object tree, bool strict /* = true */) { + std::vector flat_tree; tree_visit(tree, [&](nb::handle obj) { - if (nb::isinstance(obj)) { - flat_tree.push_back(nb::cast(obj)); + if (nb::isinstance(obj)) { + flat_tree.push_back(nb::cast(obj)); } else if (strict) { throw std::invalid_argument( "[tree_flatten] The argument should contain only arrays"); @@ -241,10 +241,10 @@ std::vector tree_flatten(nb::object tree, bool strict /* = true */) { nb::object tree_unflatten( nb::object tree, - const std::vector& values, + const std::vector& values, int index /* = 0 */) { return tree_map(tree, [&](nb::handle obj) { - if (nb::isinstance(obj)) { + if (nb::isinstance(obj)) { return nb::cast(values[index++]); } else { return nb::cast(obj); @@ -265,16 +265,16 @@ nb::object structure_sentinel() { return sentinel; } -std::pair, nb::object> tree_flatten_with_structure( +std::pair, nb::object> tree_flatten_with_structure( nb::object tree, bool strict /* = true */) { auto sentinel = structure_sentinel(); - std::vector flat_tree; + std::vector flat_tree; auto structure = tree_map( tree, [&flat_tree, sentinel = std::move(sentinel), strict](nb::handle obj) { - if (nb::isinstance(obj)) { - flat_tree.push_back(nb::cast(obj)); + if (nb::isinstance(obj)) { + flat_tree.push_back(nb::cast(obj)); return sentinel; } else if (!strict) { return nb::cast(obj); @@ -289,7 +289,7 @@ std::pair, nb::object> tree_flatten_with_structure( nb::object tree_unflatten_from_structure( nb::object structure, - const std::vector& values, + const std::vector& values, int index /* = 0 */) { auto sentinel = structure_sentinel(); return tree_map(structure, [&](nb::handle obj) { diff --git a/python/src/trees.h b/python/src/trees.h index 931b3ea6b..fc146c29d 100644 --- a/python/src/trees.h +++ b/python/src/trees.h @@ -4,8 +4,8 @@ #include "mlx/array.h" +namespace mx = mlx::core; namespace nb = nanobind; -using namespace mlx::core; void tree_visit( const std::vector& trees, @@ -27,7 +27,7 @@ void tree_visit_update( /** * Fill a pytree (recursive dict or list of dict or list) in place with the * given arrays. */ -void tree_fill(nb::object& tree, const std::vector& values); +void tree_fill(nb::object& tree, const std::vector& values); /** * Replace all the arrays from the src values with the dst values in the @@ -35,28 +35,28 @@ void tree_fill(nb::object& tree, const std::vector& values); */ void tree_replace( nb::object& tree, - const std::vector& src, - const std::vector& dst); + const std::vector& src, + const std::vector& dst); /** * Flatten a tree into a vector of arrays. If strict is true, then the * function will throw if the tree contains a leaf which is not an array. */ -std::vector tree_flatten(nb::object tree, bool strict = true); +std::vector tree_flatten(nb::object tree, bool strict = true); /** * Unflatten a tree from a vector of arrays. */ nb::object tree_unflatten( nb::object tree, - const std::vector& values, + const std::vector& values, int index = 0); -std::pair, nb::object> tree_flatten_with_structure( +std::pair, nb::object> tree_flatten_with_structure( nb::object tree, bool strict = true); nb::object tree_unflatten_from_structure( nb::object structure, - const std::vector& values, + const std::vector& values, int index = 0); diff --git a/python/src/utils.cpp b/python/src/utils.cpp index 5d1118b80..959cd98a6 100644 --- a/python/src/utils.cpp +++ b/python/src/utils.cpp @@ -4,22 +4,24 @@ #include "mlx/ops.h" #include "python/src/convert.h" -array to_array( +mx::array to_array( const ScalarOrArray& v, - std::optional dtype /* = std::nullopt */) { + std::optional dtype /* = std::nullopt */) { if (auto pv = std::get_if(&v); pv) { - return array(nb::cast(*pv), dtype.value_or(bool_)); + return mx::array(nb::cast(*pv), dtype.value_or(mx::bool_)); } else if (auto pv = std::get_if(&v); pv) { - auto out_t = dtype.value_or(int32); + auto out_t = dtype.value_or(mx::int32); // bool_ is an exception and is always promoted - return array(nb::cast(*pv), (out_t == bool_) ? int32 : out_t); + return mx::array( + nb::cast(*pv), (out_t == mx::bool_) ? mx::int32 : out_t); } else if (auto pv = std::get_if(&v); pv) { - auto out_t = dtype.value_or(float32); - return array( - nb::cast(*pv), issubdtype(out_t, floating) ? out_t : float32); + auto out_t = dtype.value_or(mx::float32); + return mx::array( + nb::cast(*pv), + mx::issubdtype(out_t, mx::floating) ? out_t : mx::float32); } else if (auto pv = std::get_if>(&v); pv) { - return array(static_cast(*pv), complex64); - } else if (auto pv = std::get_if(&v); pv) { + return mx::array(static_cast(*pv), mx::complex64); + } else if (auto pv = std::get_if(&v); pv) { return *pv; } else if (auto pv = std::get_if< nb::ndarray>(&v); @@ -30,7 +32,7 @@ array to_array( } } -std::pair to_arrays( +std::pair to_arrays( const ScalarOrArray& a, const ScalarOrArray& b) { // Four cases: @@ -39,15 +41,15 @@ std::pair to_arrays( // - If b is an array but a is not, treat a as a weak python type // - If neither is an array convert to arrays but leave their types alone auto is_mlx_array = [](const ScalarOrArray& x) { - return std::holds_alternative(x) || + return std::holds_alternative(x) || std::holds_alternative(x) && nb::hasattr(std::get(x), "__mlx_array__"); }; auto get_mlx_array = [](const ScalarOrArray& x) { - if (auto px = std::get_if(&x); px) { + if (auto px = std::get_if(&x); px) { return *px; } else { - return nb::cast(std::get(x).attr("__mlx_array__")); + return nb::cast(std::get(x).attr("__mlx_array__")); } }; @@ -66,11 +68,11 @@ std::pair to_arrays( } } -array to_array_with_accessor(nb::object obj) { - if (nb::isinstance(obj)) { - return nb::cast(obj); +mx::array to_array_with_accessor(nb::object obj) { + if (nb::isinstance(obj)) { + return nb::cast(obj); } else if (nb::hasattr(obj, "__mlx_array__")) { - return nb::cast(obj.attr("__mlx_array__")()); + return nb::cast(obj.attr("__mlx_array__")()); } else { std::ostringstream msg; msg << "Invalid type " << nb::type_name(obj.type()).c_str() diff --git a/python/src/utils.h b/python/src/utils.h index 3d5b1af97..38e474746 100644 --- a/python/src/utils.h +++ b/python/src/utils.h @@ -12,17 +12,16 @@ #include "mlx/array.h" +namespace mx = mlx::core; namespace nb = nanobind; -using namespace mlx::core; - using IntOrVec = std::variant>; using ScalarOrArray = std::variant< nb::bool_, nb::int_, nb::float_, // Must be above ndarray - array, + mx::array, // Must be above complex nb::ndarray, std::complex, @@ -45,7 +44,7 @@ inline bool is_comparable_with_array(const ScalarOrArray& v) { // Checks if the value can be compared to an array (or is already an // mlx array) if (auto pv = std::get_if(&v); pv) { - return nb::isinstance(*pv) || nb::hasattr(*pv, "__mlx_array__"); + return nb::isinstance(*pv) || nb::hasattr(*pv, "__mlx_array__"); } else { // If it's not an object, it's a scalar (nb::int_, nb::float_, etc.) // and can be compared to an array @@ -66,12 +65,12 @@ inline void throw_invalid_operation( throw std::invalid_argument(msg.str()); } -array to_array( +mx::array to_array( const ScalarOrArray& v, - std::optional dtype = std::nullopt); + std::optional dtype = std::nullopt); -std::pair to_arrays( +std::pair to_arrays( const ScalarOrArray& a, const ScalarOrArray& b); -array to_array_with_accessor(nb::object obj); +mx::array to_array_with_accessor(nb::object obj);