From 40c62c13213a480529a7ed8da62c5fd5148a5c1d Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Mon, 9 Dec 2024 11:09:02 -0800
Subject: [PATCH] Use int64 stride everywhere (#1671)

* use int64 stride everywhere
* fix ext
* fix ext
* more shape + cleanup
* one more
* few more
---
 docs/src/dev/extensions.rst | 26 +-
 examples/extensions/axpby/axpby.metal | 31 +-
 mlx/array.h | 2 +-
 mlx/backend/common/arg_reduce.cpp | 4 +-
 mlx/backend/common/binary.h | 28 +-
 mlx/backend/common/binary_two.h | 14 +-
 mlx/backend/common/common.cpp | 29 +-
 mlx/backend/common/compiled.cpp | 2 +-
 mlx/backend/common/conv.cpp | 14 +-
 mlx/backend/common/copy.cpp | 85 ++---
 mlx/backend/common/copy.h | 7 +-
 mlx/backend/common/default_primitives.cpp | 2 +-
 mlx/backend/common/indexing.cpp | 19 +-
 mlx/backend/common/masked_mm.cpp | 28 +-
 mlx/backend/common/primitives.cpp | 17 +-
 mlx/backend/common/qrf.cpp | 2 +-
 mlx/backend/common/reduce.cpp | 12 +-
 mlx/backend/common/reduce.h | 24 +-
 mlx/backend/common/reduce_utils.cpp | 22 +-
 mlx/backend/common/slicing.cpp | 14 +-
 mlx/backend/common/slicing.h | 8 +-
 mlx/backend/common/sort.cpp | 34 +-
 mlx/backend/common/ternary.h | 18 +-
 mlx/backend/common/utils.cpp | 74 +---
 mlx/backend/common/utils.h | 77 ++--
 mlx/backend/metal/binary.cpp | 4 +-
 mlx/backend/metal/compiled.cpp | 28 +-
 mlx/backend/metal/conv.cpp | 9 +-
 mlx/backend/metal/copy.cpp | 20 +-
 mlx/backend/metal/copy.h | 9 +-
 mlx/backend/metal/fft.cpp | 22 +-
 mlx/backend/metal/indexing.cpp | 16 +-
 mlx/backend/metal/jit/gemv_masked.h | 6 +-
 mlx/backend/metal/jit/indexing.h | 10 +-
 mlx/backend/metal/jit/steel_gemm.h | 6 +-
 mlx/backend/metal/kernels/arg_reduce.metal | 6 +-
 mlx/backend/metal/kernels/binary.h | 44 +--
 mlx/backend/metal/kernels/binary_two.h | 44 +--
 mlx/backend/metal/kernels/copy.h | 26 +-
 mlx/backend/metal/kernels/gather.h | 6 +-
 mlx/backend/metal/kernels/gemv.metal | 178 +++------
 mlx/backend/metal/kernels/gemv_masked.h | 20 +-
 mlx/backend/metal/kernels/gemv_masked.metal | 56 +--
 mlx/backend/metal/kernels/indexing.h | 2 +-
 mlx/backend/metal/kernels/quantized.h | 136 +++----
 mlx/backend/metal/kernels/random.metal | 2 +-
 mlx/backend/metal/kernels/reduce.metal | 14 +-
 .../metal/kernels/reduction/reduce_col.h | 30 +-
 .../metal/kernels/reduction/reduce_row.h | 31 +-
 mlx/backend/metal/kernels/scatter.h | 11 +-
 mlx/backend/metal/kernels/sort.h | 6 +-
 mlx/backend/metal/kernels/steel/attn/params.h | 8 +-
 mlx/backend/metal/kernels/steel/conv/params.h | 8 +-
 .../steel/gemm/kernels/steel_gemm_fused.h | 23 +-
 .../steel/gemm/kernels/steel_gemm_fused.metal | 24 +-
 .../steel/gemm/kernels/steel_gemm_masked.h | 24 +-
 .../gemm/kernels/steel_gemm_masked.metal | 91 ++---
 .../gemm/kernels/steel_gemm_splitk.metal | 104 ++---
 mlx/backend/metal/kernels/steel/gemm/params.h | 8 +-
 mlx/backend/metal/kernels/steel/utils.h | 10 +-
 mlx/backend/metal/kernels/ternary.h | 52 +--
 mlx/backend/metal/kernels/unary.h | 8 +-
 mlx/backend/metal/kernels/utils.h | 56 +--
 mlx/backend/metal/matmul.cpp | 219 +++++------
 mlx/backend/metal/matmul.h | 16 +-
 mlx/backend/metal/primitives.cpp | 25 +-
 mlx/backend/metal/reduce.cpp | 30 +-
 .../metal/scaled_dot_product_attention.cpp | 32 +-
 mlx/backend/metal/slicing.cpp | 15 +-
 mlx/backend/metal/slicing.h | 4 +-
 mlx/backend/metal/sort.cpp | 10 +-
 mlx/backend/metal/ternary.cpp | 4 +-
 mlx/backend/metal/unary.cpp | 2 +-
 mlx/backend/metal/utils.cpp | 10 +-
 mlx/backend/metal/utils.h | 10 +-
 mlx/einsum.cpp | 6 +-
 mlx/fast.cpp | 4 +-
 mlx/fft.cpp | 19 +-
 mlx/fft.h | 16 +-
 mlx/io/safetensors.cpp | 2 +-
 mlx/ops.cpp | 22 +-
mlx/primitives.cpp | 14 +- mlx/primitives.h | 2 - mlx/random.cpp | 33 +- mlx/random.h | 44 +-- mlx/transforms.cpp | 4 +- mlx/utils.cpp | 35 +- mlx/utils.h | 26 +- python/src/convert.cpp | 25 +- python/src/ops.cpp | 10 +- tests/arg_reduce_tests.cpp | 6 +- tests/array_tests.cpp | 26 +- tests/autograd_tests.cpp | 12 +- tests/blas_tests.cpp | 2 +- tests/creations_tests.cpp | 8 +- tests/fft_tests.cpp | 38 +- tests/linalg_tests.cpp | 22 +- tests/load_tests.cpp | 4 +- tests/ops_tests.cpp | 358 +++++++++--------- tests/random_tests.cpp | 38 +- tests/utils_tests.cpp | 43 +-- tests/vmap_tests.cpp | 24 +- 102 files changed, 1264 insertions(+), 1707 deletions(-) diff --git a/docs/src/dev/extensions.rst b/docs/src/dev/extensions.rst index 196f8bf65..c08614d03 100644 --- a/docs/src/dev/extensions.rst +++ b/docs/src/dev/extensions.rst @@ -420,8 +420,8 @@ element in the output. constant const float& alpha [[buffer(3)]], constant const float& beta [[buffer(4)]], constant const int* shape [[buffer(5)]], - constant const size_t* x_strides [[buffer(6)]], - constant const size_t* y_strides [[buffer(7)]], + constant const int64_t* x_strides [[buffer(6)]], + constant const int64_t* y_strides [[buffer(7)]], constant const int& ndim [[buffer(8)]], uint index [[thread_position_in_grid]]) { // Convert linear indices to offsets in array @@ -438,24 +438,10 @@ each instantiation a unique host name so we can identify it. .. code-block:: C++ - #define instantiate_axpby(type_name, type) \ - template [[host_name("axpby_general_" #type_name)]] \ - [[kernel]] void axpby_general( \ - device const type* x [[buffer(0)]], \ - device const type* y [[buffer(1)]], \ - device type* out [[buffer(2)]], \ - constant const float& alpha [[buffer(3)]], \ - constant const float& beta [[buffer(4)]], \ - constant const int* shape [[buffer(5)]], \ - constant const size_t* x_strides [[buffer(6)]], \ - constant const size_t* y_strides [[buffer(7)]], \ - constant const int& ndim [[buffer(8)]], \ - uint index [[thread_position_in_grid]]); - - instantiate_axpby(float32, float); - instantiate_axpby(float16, half); - instantiate_axpby(bfloat16, bfloat16_t); - instantiate_axpby(complex64, complex64_t); + instantiate_kernel("axpby_general_float32", axpby_general, float) + instantiate_kernel("axpby_general_float16", axpby_general, float16_t) + instantiate_kernel("axpby_general_bfloat16", axpby_general, bfloat16_t) + instantiate_kernel("axpby_general_complex64", axpby_general, complex64_t) The logic to determine the kernel, set the inputs, resolve the grid dimensions, and dispatch to the GPU are contained in :meth:`Axpby::eval_gpu` as shown diff --git a/examples/extensions/axpby/axpby.metal b/examples/extensions/axpby/axpby.metal index 7c5f32689..bec5a5d53 100644 --- a/examples/extensions/axpby/axpby.metal +++ b/examples/extensions/axpby/axpby.metal @@ -12,8 +12,8 @@ template constant const float& alpha [[buffer(3)]], constant const float& beta [[buffer(4)]], constant const int* shape [[buffer(5)]], - constant const size_t* x_strides [[buffer(6)]], - constant const size_t* y_strides [[buffer(7)]], + constant const int64_t* x_strides [[buffer(6)]], + constant const int64_t* y_strides [[buffer(7)]], constant const int& ndim [[buffer(8)]], uint index [[thread_position_in_grid]]) { auto x_offset = elem_to_loc(index, shape, x_strides, ndim); @@ -34,29 +34,14 @@ template static_cast(alpha) * x[index] + static_cast(beta) * y[index]; } -#define instantiate_axpby(type_name, type) \ - template [[host_name("axpby_general_" #type_name)]] [[kernel]] void \ - 
axpby_general( \ - device const type* x [[buffer(0)]], \ - device const type* y [[buffer(1)]], \ - device type* out [[buffer(2)]], \ - constant const float& alpha [[buffer(3)]], \ - constant const float& beta [[buffer(4)]], \ - constant const int* shape [[buffer(5)]], \ - constant const size_t* x_strides [[buffer(6)]], \ - constant const size_t* y_strides [[buffer(7)]], \ - constant const int& ndim [[buffer(8)]], \ - uint index [[thread_position_in_grid]]); \ - template [[host_name("axpby_contiguous_" #type_name)]] [[kernel]] void \ - axpby_contiguous( \ - device const type* x [[buffer(0)]], \ - device const type* y [[buffer(1)]], \ - device type* out [[buffer(2)]], \ - constant const float& alpha [[buffer(3)]], \ - constant const float& beta [[buffer(4)]], \ - uint index [[thread_position_in_grid]]); +// clang-format off +#define instantiate_axpby(type_name, type) \ + instantiate_kernel("axpby_general_" #type_name, axpby_general, type) \ + instantiate_kernel( \ + "axpby_contiguous_" #type_name, axpby_contiguous, type) instantiate_axpby(float32, float); instantiate_axpby(float16, half); instantiate_axpby(bfloat16, bfloat16_t); instantiate_axpby(complex64, complex64_t); +// clang-format on diff --git a/mlx/array.h b/mlx/array.h index 8c1f7e933..d76a1c0e0 100644 --- a/mlx/array.h +++ b/mlx/array.h @@ -18,7 +18,7 @@ class Primitive; using Deleter = std::function; using Shape = std::vector; -using Strides = std::vector; +using Strides = std::vector; class array { /* An array is really a node in a graph. It contains a shared ArrayDesc diff --git a/mlx/backend/common/arg_reduce.cpp b/mlx/backend/common/arg_reduce.cpp index f4a672591..00f78136c 100644 --- a/mlx/backend/common/arg_reduce.cpp +++ b/mlx/backend/common/arg_reduce.cpp @@ -13,8 +13,8 @@ template void arg_reduce(const array& in, array& out, const OpT& op, int axis) { auto axis_size = in.shape()[axis]; auto axis_stride = in.strides()[axis]; - std::vector strides = in.strides(); - std::vector shape = in.shape(); + Strides strides = in.strides(); + Shape shape = in.shape(); strides.erase(strides.begin() + axis); shape.erase(shape.begin() + axis); for (uint32_t i = 0; i < out.size(); ++i) { diff --git a/mlx/backend/common/binary.h b/mlx/backend/common/binary.h index 3898e1d40..7b9d6ec02 100644 --- a/mlx/backend/common/binary.h +++ b/mlx/backend/common/binary.h @@ -178,10 +178,10 @@ void binary_op_dims( const T* b, U* out, Op op, - const std::vector& shape, - const std::vector& a_strides, - const std::vector& b_strides, - const std::vector& out_strides, + const Shape& shape, + const Strides& a_strides, + const Strides& b_strides, + const Strides& out_strides, int axis) { auto stride_a = a_strides[axis]; auto stride_b = b_strides[axis]; @@ -212,10 +212,10 @@ void binary_op_dispatch_dims( array& out, Op op, int dim, - const std::vector& shape, - const std::vector& a_strides, - const std::vector& b_strides, - const std::vector& out_strides) { + const Shape& shape, + const Strides& a_strides, + const Strides& b_strides, + const Strides& out_strides) { const T* a_ptr = a.data(); const T* b_ptr = b.data(); U* out_ptr = out.data(); @@ -258,10 +258,10 @@ void binary_op_dispatch_dims( return; } - ContiguousIterator a_it(shape, a_strides, dim - 3); - ContiguousIterator b_it(shape, b_strides, dim - 3); - size_t stride = out_strides[dim - 4]; - for (size_t elem = 0; elem < a.size(); elem += stride) { + ContiguousIterator a_it(shape, a_strides, dim - 3); + ContiguousIterator b_it(shape, b_strides, dim - 3); + auto stride = out_strides[dim - 4]; + for 
(int64_t elem = 0; elem < a.size(); elem += stride) { binary_op_dims( a_ptr + a_it.loc, b_ptr + b_it.loc, @@ -327,7 +327,7 @@ void binary_op( const auto& strides = new_strides[2]; // Get the left-most dim such that the array is row contiguous after - auto leftmost_rc_dim = [&strides](const std::vector& arr_strides) { + auto leftmost_rc_dim = [&strides](const auto& arr_strides) { int d = arr_strides.size() - 1; for (; d >= 0 && arr_strides[d] == strides[d]; d--) { } @@ -337,7 +337,7 @@ void binary_op( auto b_rc_dim = leftmost_rc_dim(b_strides); // Get the left-most dim such that the array is a broadcasted "scalar" after - auto leftmost_s_dim = [](const std::vector& arr_strides) { + auto leftmost_s_dim = [](const auto& arr_strides) { int d = arr_strides.size() - 1; for (; d >= 0 && arr_strides[d] == 0; d--) { } diff --git a/mlx/backend/common/binary_two.h b/mlx/backend/common/binary_two.h index e9740f8aa..5088c06aa 100644 --- a/mlx/backend/common/binary_two.h +++ b/mlx/backend/common/binary_two.h @@ -16,10 +16,10 @@ void binary_op_dims( U* out_a, U* out_b, Op op, - const std::vector& shape, - const std::vector& a_strides, - const std::vector& b_strides, - const std::vector& out_strides, + const Shape& shape, + const Strides& a_strides, + const Strides& b_strides, + const Strides& out_strides, int axis) { auto stride_a = a_strides[axis]; auto stride_b = b_strides[axis]; @@ -96,9 +96,9 @@ void binary_op_dispatch_dims( return; } - ContiguousIterator a_it(shape, a_strides, ndim - 2); - ContiguousIterator b_it(shape, b_strides, ndim - 2); - size_t stride = out_strides[ndim - 3]; + ContiguousIterator a_it(shape, a_strides, ndim - 2); + ContiguousIterator b_it(shape, b_strides, ndim - 2); + auto stride = out_strides[ndim - 3]; for (size_t elem = 0; elem < a.size(); elem += stride) { binary_op_dims( a_ptr + a_it.loc, diff --git a/mlx/backend/common/common.cpp b/mlx/backend/common/common.cpp index fba9dc15b..0a677a01b 100644 --- a/mlx/backend/common/common.cpp +++ b/mlx/backend/common/common.cpp @@ -49,7 +49,7 @@ void Broadcast::eval(const std::vector& inputs, array& out) { out.set_data(nullptr); return; } - std::vector strides(out.ndim(), 0); + Strides strides(out.ndim(), 0); int diff = out.ndim() - in.ndim(); for (int i = in.ndim() - 1; i >= 0; --i) { strides[i + diff] = (in.shape()[i] == 1) ? 0 : in.strides()[i]; @@ -141,7 +141,7 @@ void NumberOfElements::eval(const std::vector& inputs, array& out) { } } -std::pair> Reshape::prepare_reshape( +std::pair Reshape::prepare_reshape( const array& in, const array& out) { // Special case for empty arrays or row contiguous arrays @@ -151,8 +151,7 @@ std::pair> Reshape::prepare_reshape( // Special case for scalars if (in.ndim() == 0) { - std::vector out_strides(out.ndim(), 0); - return {false, out_strides}; + return {false, Strides(out.ndim(), 0)}; } // Firstly let's collapse all the contiguous dimensions of the input @@ -160,7 +159,7 @@ std::pair> Reshape::prepare_reshape( // If shapes fit exactly in the contiguous dims then no copy is necessary so // let's check. 
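// ---------------------------------------------------------------------------
// Sketch, for illustration: the core of this patch is that strides become
// signed 64-bit values (Strides in mlx/array.h above; per the patch title,
// std::vector<int64_t>), so a view can step backwards through a buffer and
// offsets on large arrays do not overflow unsigned/32-bit arithmetic. The
// aliases, function, and main() below are local to this sketch; the mapping
// mirrors elem_to_loc() in mlx/backend/common/utils.h later in the patch.
#include <cstdint>
#include <cstdio>
#include <vector>

using Shape = std::vector<int32_t>;   // per-axis extents (32-bit here for the sketch)
using Strides = std::vector<int64_t>; // per-axis strides in elements, signed

int64_t elem_to_loc(int64_t elem, const Shape& shape, const Strides& strides) {
  int64_t loc = 0;
  for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
    loc += (elem % shape[i]) * strides[i];
    elem /= shape[i];
  }
  return loc;
}

int main() {
  // A reversed view of a length-5 row: base offset 4, stride -1.
  Shape shape{5};
  Strides strides{-1};
  for (int64_t e = 0; e < 5; ++e) {
    std::printf("%lld ", static_cast<long long>(4 + elem_to_loc(e, shape, strides)));
  }
  std::printf("\n"); // prints: 4 3 2 1 0
  return 0;
}
// ---------------------------------------------------------------------------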
- std::vector out_strides; + Strides out_strides; bool copy_necessary = false; int j = 0; for (int i = 0; i < out.ndim(); i++) { @@ -183,7 +182,7 @@ std::pair> Reshape::prepare_reshape( void Reshape::shared_buffer_reshape( const array& in, - const std::vector& out_strides, + const Strides& out_strides, array& out) { auto flags = in.flags(); if (flags.row_contiguous) { @@ -249,18 +248,6 @@ void Split::eval( } } -std::tuple> SliceUpdate::prepare_slice( - const array& in) { - int64_t data_offset = 0; - std::vector inp_strides(in.ndim(), 0); - for (int i = 0; i < in.ndim(); ++i) { - data_offset += start_indices_[i] * in.strides()[i]; - inp_strides[i] = in.strides()[i] * strides_[i]; - } - - return std::make_tuple(data_offset, inp_strides); -} - void StopGradient::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); move_or_copy(inputs[0], out); @@ -268,7 +255,7 @@ void StopGradient::eval(const std::vector& inputs, array& out) { void Transpose::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); - std::vector out_strides(out.ndim()); + Strides out_strides(out.ndim()); auto& in = inputs[0]; for (int ax = 0; ax < axes_.size(); ++ax) { out_strides[ax] = in.strides()[axes_[ax]]; @@ -285,8 +272,8 @@ void Transpose::eval(const std::vector& inputs, array& out) { // true, they stay true) auto flags = in.flags(); if (flags.contiguous && in.data_size() == in.size()) { - size_t f_stride = 1; - size_t b_stride = 1; + int64_t f_stride = 1; + int64_t b_stride = 1; flags.col_contiguous = true; flags.row_contiguous = true; for (int i = 0, ri = out.ndim() - 1; i < out.ndim(); ++i, --ri) { diff --git a/mlx/backend/common/compiled.cpp b/mlx/backend/common/compiled.cpp index cf6cb39b3..4c782089a 100644 --- a/mlx/backend/common/compiled.cpp +++ b/mlx/backend/common/compiled.cpp @@ -165,7 +165,7 @@ void compiled_allocate_outputs( bool move_buffers /* = false */) { if (contiguous) { int o = 0; - std::vector strides; + Strides strides; size_t data_size; array::Flags flags; for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) { diff --git a/mlx/backend/common/conv.cpp b/mlx/backend/common/conv.cpp index 67bdaeefb..879c0312e 100644 --- a/mlx/backend/common/conv.cpp +++ b/mlx/backend/common/conv.cpp @@ -746,9 +746,9 @@ void explicit_gemm_conv_1D_cpu( copy_inplace(in, in_padded_slice, CopyType::GeneralGeneral); // Make strided view - std::vector strided_shape = {N, oH, wH, C}; + Shape strided_shape = {N, oH, wH, C}; - std::vector strided_strides = { + Strides strided_strides = { in_padded.strides()[0], in_padded.strides()[1] * wt_strides[0], in_padded.strides()[1], @@ -865,9 +865,9 @@ void explicit_gemm_conv_2D_cpu( copy_inplace(in, in_padded_slice, CopyType::GeneralGeneral); // Make strided view - std::vector strided_shape = {N, oH, oW, wH, wW, C}; + Shape strided_shape = {N, oH, oW, wH, wW, C}; - std::vector strided_strides = { + Strides strided_strides = { in_padded.strides()[0], in_padded.strides()[1] * wt_strides[0], in_padded.strides()[2] * wt_strides[1], @@ -974,7 +974,7 @@ void explicit_gemm_conv_ND_cpu( copy_inplace(in, in_padded_slice, CopyType::GeneralGeneral); // Make strided view - std::vector strided_shape(oDim.size() + wDim.size() + 2); + Shape strided_shape(oDim.size() + wDim.size() + 2); strided_shape.front() = N; for (size_t i = 0; i < oDim.size(); i++) { strided_shape[i + 1] = oDim[i]; @@ -984,7 +984,7 @@ void explicit_gemm_conv_ND_cpu( } strided_shape.back() = C; - std::vector strided_strides(in.shape().size() * 2 - 2); + Strides 
strided_strides(in.shape().size() * 2 - 2); strided_strides[0] = in_padded.strides()[0]; for (size_t i = 0; i < wt_strides.size(); i++) { strided_strides[i + 1] = in_padded.strides()[i + 1] * wt_strides[i]; @@ -1000,7 +1000,7 @@ void explicit_gemm_conv_ND_cpu( in_padded, strided_strides, flags, in_strided_view.size(), 0); // Materialize strided view - std::vector strided_reshape = {N, C}; + Shape strided_reshape = {N, C}; for (const auto& o : oDim) { strided_reshape[0] *= o; } diff --git a/mlx/backend/common/copy.cpp b/mlx/backend/common/copy.cpp index 31448e1c6..80cbc9f56 100644 --- a/mlx/backend/common/copy.cpp +++ b/mlx/backend/common/copy.cpp @@ -26,13 +26,13 @@ void copy_vector(const array& src, array& dst) { std::copy(src_ptr, src_ptr + src.data_size(), dst_ptr); } -template +template inline void copy_dims( const SrcT* src, DstT* dst, - const std::vector& shape, - const std::vector& i_strides, - const std::vector& o_strides, + const Shape& shape, + const Strides& i_strides, + const Strides& o_strides, int axis) { auto stride_src = i_strides[axis]; auto stride_dst = o_strides[axis]; @@ -40,7 +40,7 @@ inline void copy_dims( for (int i = 0; i < N; i++) { if constexpr (D > 1) { - copy_dims( + copy_dims( src, dst, shape, i_strides, o_strides, axis + 1); } else { *dst = static_cast(*src); @@ -50,13 +50,13 @@ inline void copy_dims( } } -template +template void copy_general_general( const array& src, array& dst, - const std::vector& data_shape, - const std::vector& i_strides, - const std::vector& o_strides, + const Shape& data_shape, + const Strides& i_strides, + const Strides& o_strides, int64_t i_offset, int64_t o_offset) { if (data_shape.empty()) { @@ -65,30 +65,30 @@ void copy_general_general( *dst_ptr = val; return; } - auto [shape, strides] = collapse_contiguous_dims( - data_shape, std::vector>{i_strides, o_strides}); + auto [shape, strides] = + collapse_contiguous_dims(data_shape, {i_strides, o_strides}); auto src_ptr = src.data() + i_offset; auto dst_ptr = dst.data() + o_offset; int ndim = shape.size(); if (ndim == 1) { - copy_dims( + copy_dims( src_ptr, dst_ptr, shape, strides[0], strides[1], 0); return; } else if (ndim == 2) { - copy_dims( + copy_dims( src_ptr, dst_ptr, shape, strides[0], strides[1], 0); return; } else if (ndim == 3) { - copy_dims( + copy_dims( src_ptr, dst_ptr, shape, strides[0], strides[1], 0); return; } - ContiguousIterator in(shape, strides[0], ndim - 3); - ContiguousIterator out(shape, strides[1], ndim - 3); - StrideT stride = std::accumulate( - shape.end() - 3, shape.end(), 1, std::multiplies()); - for (StrideT elem = 0; elem < src.size(); elem += stride) { - copy_dims( + ContiguousIterator in(shape, strides[0], ndim - 3); + ContiguousIterator out(shape, strides[1], ndim - 3); + auto stride = std::accumulate( + shape.end() - 3, shape.end(), 1, std::multiplies()); + for (int64_t elem = 0; elem < src.size(); elem += stride) { + copy_dims( src_ptr + in.loc, dst_ptr + out.loc, shape, @@ -102,37 +102,37 @@ void copy_general_general( template inline void copy_general_general(const array& src, array& dst) { - copy_general_general( + copy_general_general( src, dst, src.shape(), src.strides(), dst.strides(), 0, 0); } -template +template void copy_general( const array& src, array& dst, - const std::vector& data_shape, - const std::vector& i_strides, - const std::vector&, + const Shape& data_shape, + const Strides& i_strides, + const Strides&, int64_t i_offset, int64_t o_offset) { - copy_general_general( + copy_general_general( src, dst, data_shape, i_strides, - 
make_contiguous_strides(data_shape), + make_contiguous_strides(data_shape), i_offset, o_offset); } template inline void copy_general(const array& src, array& dst) { - copy_general_general( + copy_general_general( src, dst, src.shape(), src.strides(), - make_contiguous_strides(src.shape()), + make_contiguous_strides(src.shape()), 0, 0); } @@ -282,13 +282,12 @@ void copy(const array& src, array& dst, CopyType ctype) { copy_inplace(src, dst, ctype); } -template void copy_inplace( const array& src, array& dst, - const std::vector& data_shape, - const std::vector& i_strides, - const std::vector& o_strides, + const Shape& data_shape, + const Strides& i_strides, + const Strides& o_strides, int64_t i_offset, int64_t o_offset, CopyType ctype) { @@ -311,24 +310,4 @@ void copy_inplace( } } -template void copy_inplace( - const array& src, - array& dst, - const std::vector& data_shape, - const std::vector& i_strides, - const std::vector& o_strides, - int64_t i_offset, - int64_t o_offset, - CopyType ctype); - -template void copy_inplace( - const array& src, - array& dst, - const std::vector& data_shape, - const std::vector& i_strides, - const std::vector& o_strides, - int64_t i_offset, - int64_t o_offset, - CopyType ctype); - } // namespace mlx::core diff --git a/mlx/backend/common/copy.h b/mlx/backend/common/copy.h index b0106257a..351790c02 100644 --- a/mlx/backend/common/copy.h +++ b/mlx/backend/common/copy.h @@ -26,13 +26,12 @@ enum class CopyType { void copy(const array& src, array& dst, CopyType ctype); void copy_inplace(const array& src, array& dst, CopyType ctype); -template void copy_inplace( const array& src, array& dst, - const std::vector& data_shape, - const std::vector& i_strides, - const std::vector& o_strides, + const Shape& data_shape, + const Strides& i_strides, + const Strides& o_strides, int64_t i_offset, int64_t o_offset, CopyType ctype); diff --git a/mlx/backend/common/default_primitives.cpp b/mlx/backend/common/default_primitives.cpp index 547d8e25d..3313ac0e1 100644 --- a/mlx/backend/common/default_primitives.cpp +++ b/mlx/backend/common/default_primitives.cpp @@ -130,7 +130,7 @@ inline void matmul_common_general( } else { array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); copy(arr, arr_copy, CopyType::General); - size_t stx = arr.shape(-1); + stx = arr.shape(-1); return std::make_tuple(false, stx, arr_copy); } }; diff --git a/mlx/backend/common/indexing.cpp b/mlx/backend/common/indexing.cpp index 1bb3eb44f..9519bf891 100644 --- a/mlx/backend/common/indexing.cpp +++ b/mlx/backend/common/indexing.cpp @@ -32,7 +32,7 @@ void gather( const std::vector& inds, array& out, const std::vector& axes, - const std::vector& slice_sizes) { + const Shape& slice_sizes) { // If the array is row contiguous then we can do a contiguous copy given // two conditions on the slice size: // - Any number of leading ones in the slice sizes are allowed @@ -80,11 +80,10 @@ void gather( T* dst_ptr = out.data(); size_t out_idx = 0; - std::vector> its(inds.begin(), inds.end()); - ContiguousIterator src_it; + std::vector its(inds.begin(), inds.end()); + ContiguousIterator src_it; if (!can_copy && src.ndim() > 0) { - src_it = std::move( - ContiguousIterator(slice_sizes, src.strides(), src.ndim())); + src_it = ContiguousIterator(slice_sizes, src.strides(), src.ndim()); } for (int idx = 0; idx < ind_size; idx++) { size_t src_idx = 0; @@ -119,7 +118,7 @@ void dispatch_gather( const std::vector& inds, array& out, const std::vector& axes, - const std::vector& size) { + const Shape& size) { switch (out.dtype()) { 
case bool_: gather(src, inds, out, axes, size); @@ -223,16 +222,16 @@ void scatter( auto inds_ndim = updates.ndim() - out.ndim(); size_t n_updates = nind ? inds[0].size() : 1; - std::vector update_shape( + Shape update_shape( updates.shape().begin() + inds_ndim, updates.shape().end()); size_t update_size = 1; for (auto us : update_shape) { update_size *= us; } - std::vector> its(inds.begin(), inds.end()); - ContiguousIterator update_it(updates); - ContiguousIterator out_it(update_shape, out.strides(), out.ndim()); + std::vector its(inds.begin(), inds.end()); + ContiguousIterator update_it(updates); + ContiguousIterator out_it(update_shape, out.strides(), out.ndim()); for (int i = 0; i < n_updates; ++i) { size_t out_offset = 0; diff --git a/mlx/backend/common/masked_mm.cpp b/mlx/backend/common/masked_mm.cpp index d0286f0fd..f6c8300f2 100644 --- a/mlx/backend/common/masked_mm.cpp +++ b/mlx/backend/common/masked_mm.cpp @@ -19,10 +19,10 @@ inline void mask_matrix( int block_size, const int X, const int Y, - const size_t X_data_str, - const size_t Y_data_str, - const size_t X_mask_str, - const size_t Y_mask_str, + const int64_t X_data_str, + const int64_t Y_data_str, + const int64_t X_mask_str, + const int64_t Y_mask_str, const size_t mask_offset) { int tX = (X + block_size - 1) / block_size; int tY = (Y + block_size - 1) / block_size; @@ -84,7 +84,7 @@ void BlockMaskedMM::eval(const std::vector& inputs, array& out) { } else { array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); copy(arr, arr_copy, CopyType::General); - size_t stx = arr.shape(-1); + int64_t stx = arr.shape(-1); return std::make_tuple(false, stx, arr_copy); } }; @@ -117,13 +117,13 @@ void BlockMaskedMM::eval(const std::vector& inputs, array& out) { int Y, size_t X_data_str, size_t Y_data_str) { - size_t mask_offset = elem_to_loc( + auto mask_offset = elem_to_loc( mask.shape(-1) * mask.shape(-2) * batch_idx, mask.shape(), mask.strides()); - size_t X_mask_str = mask.strides()[mask.ndim() - 2]; - size_t Y_mask_str = mask.strides()[mask.ndim() - 1]; + auto X_mask_str = mask.strides()[mask.ndim() - 2]; + auto Y_mask_str = mask.strides()[mask.ndim() - 1]; if (mask.dtype() == bool_) { return mask_matrix( @@ -230,7 +230,7 @@ void GatherMM::eval(const std::vector& inputs, array& out) { } else { array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); copy(arr, arr_copy, CopyType::General); - size_t stx = arr.shape(-1); + int64_t stx = arr.shape(-1); return std::make_tuple(false, stx, arr_copy); } }; @@ -262,13 +262,13 @@ void GatherMM::eval(const std::vector& inputs, array& out) { auto& lhs_indices = inputs[2]; auto& rhs_indices = inputs[3]; - std::vector batch_shape = get_batch_dims(out.shape()); + auto batch_shape = get_batch_dims(out.shape()); int batch_ndim = batch_shape.size(); - std::vector batch_shape_A = get_batch_dims(a.shape()); - std::vector batch_strides_A = get_batch_dims(a.strides()); - std::vector batch_shape_B = get_batch_dims(b.shape()); - std::vector batch_strides_B = get_batch_dims(b.strides()); + auto batch_shape_A = get_batch_dims(a.shape()); + auto batch_strides_A = get_batch_dims(a.strides()); + auto batch_shape_B = get_batch_dims(b.shape()); + auto batch_strides_B = get_batch_dims(b.strides()); const uint32_t* lhs_indices_ptr = lhs_indices.data(); const uint32_t* rhs_indices_ptr = rhs_indices.data(); diff --git a/mlx/backend/common/primitives.cpp b/mlx/backend/common/primitives.cpp index 00338ef88..12042ed0f 100644 --- a/mlx/backend/common/primitives.cpp +++ b/mlx/backend/common/primitives.cpp @@ -498,14 +498,15 
@@ void Slice::eval(const std::vector& inputs, array& out) { auto& in = inputs[0]; // Calculate out strides, initial offset and if copy needs to be made - auto [copy_needed, data_offset, inp_strides] = - prepare_slice(in, start_indices_, strides_); + auto [data_offset, inp_strides] = prepare_slice(in, start_indices_, strides_); + auto copy_needed = std::any_of( + strides_.begin(), strides_.end(), [](auto i) { return i < 0; }); // Do copy if needed if (copy_needed) { out.set_data(allocator::malloc_or_wait(out.nbytes())); - std::vector ostrides{out.strides().begin(), out.strides().end()}; - copy_inplace( + Strides ostrides{out.strides().begin(), out.strides().end()}; + copy_inplace( /* const array& src = */ in, /* array& dst = */ out, /* const std::vector& data_shape = */ out.shape(), @@ -523,7 +524,7 @@ void Slice::eval(const std::vector& inputs, array& out) { } } size_t data_size = data_end - data_offset; - std::vector ostrides{inp_strides.begin(), inp_strides.end()}; + Strides ostrides{inp_strides.begin(), inp_strides.end()}; shared_buffer_slice(in, ostrides, data_offset, data_size, out); } } @@ -550,11 +551,11 @@ void SliceUpdate::eval(const std::vector& inputs, array& out) { copy(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype); // Calculate out strides, initial offset and if copy needs to be made - auto [data_offset, out_strides] = prepare_slice(out); + auto [data_offset, out_strides] = prepare_slice(in, start_indices_, strides_); // Do copy - std::vector upd_strides{upd.strides().begin(), upd.strides().end()}; - copy_inplace( + Strides upd_strides{upd.strides().begin(), upd.strides().end()}; + copy_inplace( /* const array& src = */ upd, /* array& dst = */ out, /* const std::vector& data_shape = */ upd.shape(), diff --git a/mlx/backend/common/qrf.cpp b/mlx/backend/common/qrf.cpp index 9383f6c88..c1f123aaf 100644 --- a/mlx/backend/common/qrf.cpp +++ b/mlx/backend/common/qrf.cpp @@ -54,7 +54,7 @@ void qrf_impl(const array& a, array& q, array& r) { // Copy the input to be column contiguous flags.col_contiguous = num_matrices == 1; flags.row_contiguous = false; - std::vector strides = in.strides(); + auto strides = in.strides(); strides[in.ndim() - 2] = 1; strides[in.ndim() - 1] = M; in.set_data( diff --git a/mlx/backend/common/reduce.cpp b/mlx/backend/common/reduce.cpp index 049bb7409..332ee7169 100644 --- a/mlx/backend/common/reduce.cpp +++ b/mlx/backend/common/reduce.cpp @@ -174,19 +174,19 @@ void reduce_dispatch_min_max( void nd_loop( std::function callback, - const std::vector& shape, - const std::vector& strides) { + const Shape& shape, + const Strides& strides) { std::function loop_inner; loop_inner = [&](int dim, int offset) { if (dim < shape.size() - 1) { - int size = shape[dim]; - size_t stride = strides[dim]; + auto size = shape[dim]; + auto stride = strides[dim]; for (int i = 0; i < size; i++) { loop_inner(dim + 1, offset + i * stride); } } else { - int size = shape[dim]; - size_t stride = strides[dim]; + auto size = shape[dim]; + auto stride = strides[dim]; for (int i = 0; i < size; i++) { callback(offset + i * stride); } diff --git a/mlx/backend/common/reduce.h b/mlx/backend/common/reduce.h index 1815d65bf..35d8f9e48 100644 --- a/mlx/backend/common/reduce.h +++ b/mlx/backend/common/reduce.h @@ -38,13 +38,10 @@ enum ReductionOpType { struct ReductionPlan { ReductionOpType type; - std::vector shape; - std::vector strides; + Shape shape; + Strides strides; - ReductionPlan( - ReductionOpType type_, - std::vector shape_, - std::vector strides_) + 
ReductionPlan(ReductionOpType type_, Shape shape_, Strides strides_) : type(type_), shape(std::move(shape_)), strides(std::move(strides_)) {} ReductionPlan(ReductionOpType type_) : type(type_) {} }; @@ -55,10 +52,10 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector& axes); // Should this be in utils? void nd_loop( std::function callback, - const std::vector& shape, - const std::vector& strides); + const Shape& shape, + const Strides& strides); -std::pair, std::vector> shapes_without_reduction_axes( +std::pair shapes_without_reduction_axes( const array& x, const std::vector& axes); @@ -113,9 +110,6 @@ void reduction_op( return; } - std::vector shape; - std::vector strides; - if (plan.type == ContiguousReduce && plan.shape.size() == 1) { int reduction_size = plan.shape[0]; const T* x_ptr = x.data(); @@ -135,7 +129,7 @@ void reduction_op( U* out_ptr = out.data(); // Unrolling the following loop (and implementing it in order for // ContiguousReduce) should hold extra performance boost. - std::tie(shape, strides) = shapes_without_reduction_axes(x, axes); + auto [shape, strides] = shapes_without_reduction_axes(x, axes); if (plan.shape.size() == 0) { for (int i = 0; i < out.size(); i++, out_ptr++) { int offset = elem_to_loc(i, shape, strides); @@ -181,7 +175,7 @@ void reduction_op( plan.strides.pop_back(); const T* x_ptr = x.data(); U* out_ptr = out.data(); - std::tie(shape, strides) = shapes_without_reduction_axes(x, axes); + auto [shape, strides] = shapes_without_reduction_axes(x, axes); if (plan.shape.size() == 0) { for (int i = 0; i < out.size(); i += reduction_stride) { int offset = elem_to_loc(i, shape, strides); @@ -211,7 +205,7 @@ void reduction_op( if (plan.type == GeneralReduce) { const T* x_ptr = x.data(); U* out_ptr = out.data(); - std::tie(shape, strides) = shapes_without_reduction_axes(x, axes); + auto [shape, strides] = shapes_without_reduction_axes(x, axes); for (int i = 0; i < out.size(); i++, out_ptr++) { int offset = elem_to_loc(i, shape, strides); U val = init; diff --git a/mlx/backend/common/reduce_utils.cpp b/mlx/backend/common/reduce_utils.cpp index 64049873e..5c7f63b75 100644 --- a/mlx/backend/common/reduce_utils.cpp +++ b/mlx/backend/common/reduce_utils.cpp @@ -4,11 +4,11 @@ namespace mlx::core { -std::pair, std::vector> shapes_without_reduction_axes( +std::pair shapes_without_reduction_axes( const array& x, const std::vector& axes) { - std::vector shape = x.shape(); - std::vector strides = x.strides(); + auto shape = x.shape(); + auto strides = x.strides(); for (int i = axes.size() - 1; i >= 0; i--) { int a = axes[i]; @@ -29,8 +29,8 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector& axes) { // Row contiguous input so the output is row contiguous if (x.flags().row_contiguous) { // Merge consecutive axes - std::vector shape = {x.shape(axes[0])}; - std::vector strides = {x.strides()[axes[0]]}; + Shape shape = {x.shape(axes[0])}; + Strides strides = {x.strides()[axes[0]]}; for (int i = 1; i < axes.size(); i++) { if (axes[i] - 1 == axes[i - 1] && x.shape(axes[i]) > 1) { shape.back() *= x.shape(axes[i]); @@ -69,7 +69,7 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector& axes) { // Sort reduction axes by stride in order to merge them and figure out if we // have a contiguous reduction. 
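// ---------------------------------------------------------------------------
// Sketch, for illustration: what shapes_without_reduction_axes() gives the
// reduction loops above is (shape, strides) with the reduced axes erased, so
// iterating the remaining dims yields, for each output element, the input
// offset where that element's reduction terms start. Standalone example with
// sketch-local names, assuming a row-contiguous {2, 3, 4} input reduced over
// axis 1 (size 3, stride 4).
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  std::vector<int32_t> shape{2, 3, 4};
  std::vector<int64_t> strides{12, 4, 1};
  int axis = 1;

  auto out_shape = shape;
  auto out_strides = strides;
  out_shape.erase(out_shape.begin() + axis);
  out_strides.erase(out_strides.begin() + axis);

  // Each visited offset is where one reduction of length shape[axis]
  // (stride strides[axis]) begins in the input buffer.
  for (int i = 0; i < out_shape[0]; ++i) {
    for (int j = 0; j < out_shape[1]; ++j) {
      int64_t offset = i * out_strides[0] + j * out_strides[1];
      std::printf("output (%d,%d) starts at input offset %lld\n",
                  i, j, static_cast<long long>(offset));
    }
  }
  return 0;
}
// ---------------------------------------------------------------------------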
- std::vector> reductions; + std::vector> reductions; for (auto a : axes) { if (x.shape(a) > 1) { reductions.push_back(std::make_pair(x.shape(a), x.strides()[a])); @@ -93,8 +93,8 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector& axes) { } } - std::vector shape; - std::vector strides; + Shape shape; + Strides strides; for (auto r : reductions) { shape.push_back(r.first); strides.push_back(r.second); @@ -109,15 +109,15 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector& axes) { // Delegate to the general strided reduction op if the axes after // strides.back() are contiguous. if (strides.back() > 1) { - int size = 1; + int64_t size = 1; bool have_expand = false; for (int i = x.ndim() - 1; i >= 0; i--) { if (axes.back() == i) { continue; } - size_t stride_i = x.strides()[i]; - int shape_i = x.shape(i); + auto stride_i = x.strides()[i]; + auto shape_i = x.shape(i); if (stride_i == 0) { if (shape_i == 1) { continue; diff --git a/mlx/backend/common/slicing.cpp b/mlx/backend/common/slicing.cpp index 343f0ff57..9a51aefaf 100644 --- a/mlx/backend/common/slicing.cpp +++ b/mlx/backend/common/slicing.cpp @@ -4,24 +4,22 @@ namespace mlx::core { -std::tuple> prepare_slice( +std::tuple prepare_slice( const array& in, - const std::vector& start_indices, - const std::vector& strides) { + const Shape& start_indices, + const Shape& strides) { int64_t data_offset = 0; - bool copy_needed = false; - std::vector inp_strides(in.ndim(), 0); + Strides inp_strides(in.ndim(), 0); for (int i = 0; i < in.ndim(); ++i) { data_offset += start_indices[i] * in.strides()[i]; inp_strides[i] = in.strides()[i] * strides[i]; - copy_needed |= strides[i] < 0; } - return std::make_tuple(copy_needed, data_offset, inp_strides); + return std::make_tuple(data_offset, inp_strides); } void shared_buffer_slice( const array& in, - const std::vector& out_strides, + const Strides& out_strides, size_t data_offset, size_t data_size, array& out) { diff --git a/mlx/backend/common/slicing.h b/mlx/backend/common/slicing.h index 9ee8216f4..eda37320d 100644 --- a/mlx/backend/common/slicing.h +++ b/mlx/backend/common/slicing.h @@ -6,14 +6,14 @@ namespace mlx::core { -std::tuple> prepare_slice( +std::tuple prepare_slice( const array& in, - const std::vector& start_indices, - const std::vector& strides); + const Shape& start_indices, + const Shape& strides); void shared_buffer_slice( const array& in, - const std::vector& out_strides, + const Strides& out_strides, size_t data_offset, size_t data_size, array& out); diff --git a/mlx/backend/common/sort.cpp b/mlx/backend/common/sort.cpp index 1d3d80218..29e4d9d5e 100644 --- a/mlx/backend/common/sort.cpp +++ b/mlx/backend/common/sort.cpp @@ -25,7 +25,7 @@ struct StridedIterator { // Constructors StridedIterator() = default; - explicit StridedIterator(T* ptr, size_t stride, difference_type offset = 0) + explicit StridedIterator(T* ptr, int64_t stride, difference_type offset = 0) : ptr_(ptr + offset * stride), stride_(stride) {} explicit StridedIterator(array& arr, int axis, difference_type offset = 0) @@ -99,7 +99,7 @@ struct StridedIterator { } private: - size_t stride_; + int64_t stride_; T* ptr_; }; @@ -120,11 +120,11 @@ void sort(const array& in, array& out, int axis) { auto remaining_strides = out.strides(); remaining_strides.erase(remaining_strides.begin() + axis); - size_t axis_stride = out.strides()[axis]; - int axis_size = out.shape(axis); + auto axis_stride = out.strides()[axis]; + auto axis_size = out.shape(axis); // Perform sorting in place - 
ContiguousIterator src_it( + ContiguousIterator src_it( remaining_shape, remaining_strides, remaining_shape.size()); for (int i = 0; i < n_rows; i++) { T* data_ptr = out.data() + src_it.loc; @@ -158,14 +158,14 @@ void argsort(const array& in, array& out, int axis) { auto out_remaining_strides = out.strides(); out_remaining_strides.erase(out_remaining_strides.begin() + axis); - size_t in_stride = in.strides()[axis]; - size_t out_stride = out.strides()[axis]; - int axis_size = in.shape(axis); + auto in_stride = in.strides()[axis]; + auto out_stride = out.strides()[axis]; + auto axis_size = in.shape(axis); // Perform sorting - ContiguousIterator in_it( + ContiguousIterator in_it( in_remaining_shape, in_remaining_strides, in_remaining_shape.size()); - ContiguousIterator out_it( + ContiguousIterator out_it( out_remaining_shape, out_remaining_strides, out_remaining_shape.size()); for (int i = 0; i < n_rows; i++) { const T* data_ptr = in.data() + in_it.loc; @@ -208,13 +208,13 @@ void partition(const array& in, array& out, int axis, int kth) { auto remaining_strides = in.strides(); remaining_strides.erase(remaining_strides.begin() + axis); - size_t axis_stride = in.strides()[axis]; + auto axis_stride = in.strides()[axis]; int axis_size = in.shape(axis); kth = kth < 0 ? kth + axis_size : kth; // Perform partition in place - ContiguousIterator src_it( + ContiguousIterator src_it( remaining_shape, remaining_strides, remaining_shape.size()); for (int i = 0; i < n_rows; i++) { T* data_ptr = out.data() + src_it.loc; @@ -249,16 +249,16 @@ void argpartition(const array& in, array& out, int axis, int kth) { auto out_remaining_strides = out.strides(); out_remaining_strides.erase(out_remaining_strides.begin() + axis); - size_t in_stride = in.strides()[axis]; - size_t out_stride = out.strides()[axis]; - int axis_size = in.shape(axis); + auto in_stride = in.strides()[axis]; + auto out_stride = out.strides()[axis]; + auto axis_size = in.shape(axis); kth = kth < 0 ? 
kth + axis_size : kth; // Perform partition - ContiguousIterator in_it( + ContiguousIterator in_it( in_remaining_shape, in_remaining_strides, in_remaining_shape.size()); - ContiguousIterator out_it( + ContiguousIterator out_it( out_remaining_shape, out_remaining_strides, out_remaining_shape.size()); for (int i = 0; i < n_rows; i++) { const T* data_ptr = in.data() + in_it.loc; diff --git a/mlx/backend/common/ternary.h b/mlx/backend/common/ternary.h index dcd5a8676..050e07e02 100644 --- a/mlx/backend/common/ternary.h +++ b/mlx/backend/common/ternary.h @@ -78,11 +78,11 @@ void ternary_op_dims( const T3* c, U* out, Op op, - const std::vector& shape, - const std::vector& a_strides, - const std::vector& b_strides, - const std::vector& c_strides, - const std::vector& out_strides, + const Shape& shape, + const Strides& a_strides, + const Strides& b_strides, + const Strides& c_strides, + const Strides& out_strides, int axis) { auto stride_a = a_strides[axis]; auto stride_b = b_strides[axis]; @@ -164,10 +164,10 @@ void ternary_op_dispatch_dims( return; } - ContiguousIterator a_it(shape, a_strides, ndim - 2); - ContiguousIterator b_it(shape, b_strides, ndim - 2); - ContiguousIterator c_it(shape, c_strides, ndim - 2); - size_t stride = out_strides[ndim - 3]; + ContiguousIterator a_it(shape, a_strides, ndim - 2); + ContiguousIterator b_it(shape, b_strides, ndim - 2); + ContiguousIterator c_it(shape, c_strides, ndim - 2); + auto stride = out_strides[ndim - 3]; for (size_t elem = 0; elem < a.size(); elem += stride) { ternary_op_dims( a_ptr + a_it.loc, diff --git a/mlx/backend/common/utils.cpp b/mlx/backend/common/utils.cpp index 97fdfe968..1179ee9be 100644 --- a/mlx/backend/common/utils.cpp +++ b/mlx/backend/common/utils.cpp @@ -15,7 +15,7 @@ void move_or_copy(const array& in, array& out) { void move_or_copy( const array& in, array& out, - const std::vector& strides, + const Strides& strides, array::Flags flags, size_t data_size, size_t offset /* = 0 */) { @@ -26,15 +26,13 @@ void move_or_copy( } } -template -std::tuple, std::vector>> -collapse_contiguous_dims_impl( - const std::vector& shape, - const std::vector>& strides, - StrideT size_cap) { +std::tuple> collapse_contiguous_dims( + const Shape& shape, + const std::vector& strides, + int64_t size_cap) { // Make a vector that has axes separated with -1. Collapse all axes between // -1. 
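// ---------------------------------------------------------------------------
// Sketch, for illustration: what the collapse does in its simplest form. Walk
// the dims and merge axis i into its neighbor whenever
// strides[i] * shape[i] == collapsed_strides.back(), i.e. when the two axes
// cover one contiguous run of memory. The function below is a simplified
// stand-in for the single-array overload later in this file (it omits the
// size_cap handling); all names are local to this sketch.
#include <cstdint>
#include <cstdio>
#include <utility>
#include <vector>

using Shape = std::vector<int32_t>;
using Strides = std::vector<int64_t>;

std::pair<Shape, Strides> collapse(const Shape& shape, const Strides& strides) {
  Shape out_shape;
  Strides out_strides;
  for (size_t i = 0; i < shape.size(); ++i) {
    if (shape[i] == 1) {
      continue; // size-1 axes never affect the layout
    }
    if (!out_shape.empty() && strides[i] * shape[i] == out_strides.back()) {
      out_shape.back() *= shape[i]; // merge into the previous axis
      out_strides.back() = strides[i];
    } else {
      out_shape.push_back(shape[i]);
      out_strides.push_back(strides[i]);
    }
  }
  return {out_shape, out_strides};
}

int main() {
  auto [s1, t1] = collapse({2, 3, 4}, {12, 4, 1});
  std::printf("%zu dim(s), size %d, stride %lld\n",
              s1.size(), s1[0], static_cast<long long>(t1[0]));
  // -> 1 dim(s), size 24, stride 1: a fully contiguous block collapses away.
  auto [s2, t2] = collapse({3, 2}, {1, 3});
  std::printf("%zu dim(s)\n", s2.size()); // -> 2 dim(s): a transpose cannot merge.
  return 0;
}
// ---------------------------------------------------------------------------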
- std::vector to_collapse; + Shape to_collapse; if (shape.size() > 0) { if (shape[0] != 1) { to_collapse.push_back(0); @@ -43,7 +41,7 @@ collapse_contiguous_dims_impl( for (int i = 1; i < shape.size(); i++) { bool contiguous = true; size *= shape[i]; - for (const std::vector& st : strides) { + for (const auto& st : strides) { if (st[i] * shape[i] != st[i - 1] || size > size_cap) { contiguous = false; size = shape[i]; @@ -60,8 +58,8 @@ collapse_contiguous_dims_impl( to_collapse.push_back(-1); } - std::vector out_shape; - std::vector> out_strides(strides.size()); + Shape out_shape; + std::vector out_strides(strides.size()); for (int i = 0;;) { while (i < to_collapse.size() && to_collapse[i] == -1) { ++i; @@ -76,7 +74,7 @@ collapse_contiguous_dims_impl( } out_shape.push_back(current_shape); for (int j = 0; j < strides.size(); j++) { - const std::vector& st = strides[j]; + const auto& st = strides[j]; out_strides[j].push_back(st[to_collapse[k - 1]]); } i = k + 1; @@ -91,29 +89,12 @@ collapse_contiguous_dims_impl( return std::make_tuple(out_shape, out_strides); } -std::tuple, std::vector>> -collapse_contiguous_dims( - const std::vector& shape, - const std::vector>& strides, - int64_t size_cap /* = std::numeric_limits::max() */) { - return collapse_contiguous_dims_impl(shape, strides, size_cap); -} - -std::tuple, std::vector>> -collapse_contiguous_dims( - const std::vector& shape, - const std::vector>& strides, - size_t size_cap /* = std::numeric_limits::max() */) { - return collapse_contiguous_dims_impl(shape, strides, size_cap); -} - -template -std::pair, std::vector> collapse_contiguous_dims_impl( - const std::vector& shape, - const std::vector& strides, - StrideT size_cap) { - std::vector collapsed_shape; - std::vector collapsed_strides; +std::pair collapse_contiguous_dims( + const Shape& shape, + const Strides& strides, + int64_t size_cap) { + Shape collapsed_shape; + Strides collapsed_strides; if (shape.size() > 0) { collapsed_shape.push_back(shape[0]); @@ -123,7 +104,7 @@ std::pair, std::vector> collapse_contiguous_dims_impl( continue; } else if ( strides[i] * shape[i] != collapsed_strides.back() || - collapsed_shape.back() * static_cast(shape[i]) > size_cap) { + collapsed_shape.back() * static_cast(shape[i]) > size_cap) { collapsed_shape.push_back(shape[i]); collapsed_strides.push_back(strides[i]); } else { @@ -136,25 +117,10 @@ std::pair, std::vector> collapse_contiguous_dims_impl( return std::make_pair(collapsed_shape, collapsed_strides); } -std::pair, std::vector> collapse_contiguous_dims( - const std::vector& shape, - const std::vector& strides, - int64_t size_cap /* = std::numeric_limits::max() */) { - return collapse_contiguous_dims_impl(shape, strides, size_cap); -} - -std::pair, std::vector> collapse_contiguous_dims( - const std::vector& shape, - const std::vector& strides, - size_t size_cap /* = std::numeric_limits::max() */) { - return collapse_contiguous_dims_impl(shape, strides, size_cap); -} - -std::pair, std::vector> collapse_contiguous_dims( +std::pair collapse_contiguous_dims( const array& a, - size_t size_cap /* = std::numeric_limits::max()*/) { - return collapse_contiguous_dims_impl( - a.shape(), a.strides(), size_cap); + int64_t size_cap /* = std::numeric_limits::max()*/) { + return collapse_contiguous_dims(a.shape(), a.strides(), size_cap); } } // namespace mlx::core diff --git a/mlx/backend/common/utils.h b/mlx/backend/common/utils.h index 3d466ed51..c67189b5d 100644 --- a/mlx/backend/common/utils.h +++ b/mlx/backend/common/utils.h @@ -8,12 +8,9 @@ namespace 
mlx::core { -template -inline StrideT elem_to_loc( - int elem, - const std::vector& shape, - const std::vector& strides) { - StrideT loc = 0; +inline int64_t +elem_to_loc(int elem, const Shape& shape, const Strides& strides) { + int64_t loc = 0; for (int i = shape.size() - 1; i >= 0; --i) { auto q_and_r = ldiv(elem, shape[i]); loc += q_and_r.rem * strides[i]; @@ -22,16 +19,15 @@ inline StrideT elem_to_loc( return loc; } -inline size_t elem_to_loc(int elem, const array& a) { +inline int64_t elem_to_loc(int elem, const array& a) { if (a.flags().row_contiguous) { return elem; } return elem_to_loc(elem, a.shape(), a.strides()); } -template -std::vector make_contiguous_strides(const std::vector& shape) { - std::vector strides(shape.size(), 1); +inline Strides make_contiguous_strides(const Shape& shape) { + Strides strides(shape.size(), 1); for (int i = shape.size() - 1; i > 0; i--) { strides[i - 1] = strides[i] * shape[i]; } @@ -44,22 +40,15 @@ std::vector make_contiguous_strides(const std::vector& shape) { // // When multiple arrays are passed they should all have the same shape. The // collapsed axes are also the same so one shape is returned. -std::tuple, std::vector>> -collapse_contiguous_dims( - const std::vector& shape, - const std::vector>& strides, +std::tuple> collapse_contiguous_dims( + const Shape& shape, + const std::vector& strides, int64_t size_cap = std::numeric_limits::max()); -std::tuple, std::vector>> -collapse_contiguous_dims( - const std::vector& shape, - const std::vector>& strides, - size_t size_cap = std::numeric_limits::max()); - -inline std::tuple, std::vector>> -collapse_contiguous_dims( + +inline std::tuple> collapse_contiguous_dims( const std::vector& xs, size_t size_cap = std::numeric_limits::max()) { - std::vector> strides; + std::vector strides; for (auto& x : xs) { strides.emplace_back(x.strides()); } @@ -73,19 +62,14 @@ inline auto collapse_contiguous_dims(Arrays&&... xs) { } // The single array version of the above. 
-std::pair, std::vector> collapse_contiguous_dims( - const std::vector& shape, - const std::vector& strides, +std::pair collapse_contiguous_dims( + const Shape& shape, + const Strides& strides, int64_t size_cap = std::numeric_limits::max()); -std::pair, std::vector> collapse_contiguous_dims( - const std::vector& shape, - const std::vector& strides, - size_t size_cap = std::numeric_limits::max()); -std::pair, std::vector> collapse_contiguous_dims( +std::pair collapse_contiguous_dims( const array& a, - size_t size_cap = std::numeric_limits::max()); + int64_t size_cap = std::numeric_limits::max()); -template struct ContiguousIterator { inline void step() { int dims = shape_.size(); @@ -102,7 +86,7 @@ struct ContiguousIterator { loc += strides_[i]; } - void seek(StrideT n) { + void seek(int64_t n) { loc = 0; for (int i = shape_.size() - 1; i >= 0; --i) { auto q_and_r = ldiv(n, shape_[i]); @@ -128,32 +112,29 @@ struct ContiguousIterator { } explicit ContiguousIterator( - const std::vector& shape, - const std::vector& strides, + const Shape& shape, + const Strides& strides, int dims) : shape_(shape.begin(), shape.begin() + dims), strides_(strides.begin(), strides.begin() + dims) { if (!shape_.empty()) { std::tie(shape_, strides_) = collapse_contiguous_dims(shape_, strides_); - pos_ = std::vector(shape_.size(), 0); + pos_ = Shape(shape_.size(), 0); } } - StrideT loc{0}; + int64_t loc{0}; private: - std::vector shape_; - std::vector strides_; - std::vector pos_; + Shape shape_; + Strides strides_; + Shape pos_; }; -template -inline auto check_contiguity( - const std::vector& shape, - const std::vector& strides) { +inline auto check_contiguity(const Shape& shape, const Strides& strides) { size_t no_broadcast_data_size = 1; - size_t f_stride = 1; - size_t b_stride = 1; + int64_t f_stride = 1; + int64_t b_stride = 1; bool is_row_contiguous = true; bool is_col_contiguous = true; @@ -182,7 +163,7 @@ void move_or_copy(const array& in, array& out); void move_or_copy( const array& in, array& out, - const std::vector& strides, + const Strides& strides, array::Flags flags, size_t data_size, size_t offset = 0); diff --git a/mlx/backend/metal/binary.cpp b/mlx/backend/metal/binary.cpp index 8e3015790..aeb1df354 100644 --- a/mlx/backend/metal/binary.cpp +++ b/mlx/backend/metal/binary.cpp @@ -75,8 +75,8 @@ void binary_op_gpu_inplace( auto [shape, strides] = collapse_contiguous_dims(a, b, out); return std::make_tuple(shape, strides[0], strides[1], strides[2]); } else { - std::vector e; - return std::make_tuple(std::vector{}, e, e, e); + decltype(a.strides()) e{}; + return std::make_tuple(decltype(a.shape()){}, e, e, e); } }; auto [shape, strides_a, strides_b, strides_out] = maybe_collapse(); diff --git a/mlx/backend/metal/compiled.cpp b/mlx/backend/metal/compiled.cpp index ffe5af6e5..b75044f1a 100644 --- a/mlx/backend/metal/compiled.cpp +++ b/mlx/backend/metal/compiled.cpp @@ -67,7 +67,7 @@ inline void build_kernel( if (add_indices) { os += fmt::format( - " constant const size_t* in_strides [[buffer({0})]],\n", cnt++); + " constant const int64_t* in_strides [[buffer({0})]],\n", cnt++); } // Add the output arguments @@ -81,7 +81,7 @@ inline void build_kernel( // Add output strides and shape to extract the indices. 
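// ---------------------------------------------------------------------------
// Sketch, for illustration: how ContiguousIterator (defined in
// mlx/backend/common/utils.h above) is typically used. It walks a possibly
// non-contiguous array in logical row-major order while carrying a running
// memory offset, so the hot loop does not call elem_to_loc() per element.
// The struct below is a simplified stand-in with sketch-local names.
#include <cstdint>
#include <cstdio>
#include <utility>
#include <vector>

struct Iter {
  std::vector<int32_t> shape;
  std::vector<int64_t> strides;
  std::vector<int32_t> pos;
  int64_t loc{0};

  Iter(std::vector<int32_t> s, std::vector<int64_t> st)
      : shape(std::move(s)), strides(std::move(st)), pos(shape.size(), 0) {}

  void step() {
    int i = static_cast<int>(shape.size()) - 1;
    while (i > 0 && pos[i] >= shape[i] - 1) {
      pos[i] = 0;
      loc -= (shape[i] - 1) * strides[i]; // rewind this axis
      --i;
    }
    ++pos[i];
    loc += strides[i];
  }
};

int main() {
  // A 2x3 logical view over a column-major (transposed) buffer: strides {1, 2}.
  Iter it({2, 3}, {1, 2});
  for (int e = 0; e < 6; ++e, it.step()) {
    std::printf("%lld ", static_cast<long long>(it.loc)); // 0 2 4 1 3 5
  }
  std::printf("\n");
  return 0;
}
// ---------------------------------------------------------------------------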
if (!contiguous) { os += fmt::format( - " constant const size_t* output_strides [[buffer({0})]],\n", cnt++); + " constant const int64_t* output_strides [[buffer({0})]],\n", cnt++); os += fmt::format( " constant const int* output_shape [[buffer({0})]],\n", cnt++); } @@ -93,11 +93,11 @@ inline void build_kernel( os += " uint3 pos [[thread_position_in_grid]],\n"; os += " uint3 grid [[threads_per_grid]]) {\n"; - std::string idx_type = use_big_index ? "size_t" : "uint"; + std::string idx_type = use_big_index ? "int64_t" : "uint"; if (contiguous && use_big_index) { // This is only used for contiguous kernels which don't have // a third grid dimension - os += " size_t index = pos.x + grid.x * size_t(pos.y);\n"; + os += " int64_t index = pos.x + grid.x * int64_t(pos.y);\n"; } else if (work_per_thread > 1) { os += fmt::format(" constexpr int N_ = {0};\n", work_per_thread); os += fmt::format( @@ -144,20 +144,18 @@ inline void build_kernel( os += fmt::format(" {0} index_{1} = ", idx_type, xname); if (ndim == 1) { int offset = i * ndim; - os += fmt::format( - "elem_to_loc_1(pos.x, in_strides[{0}]);\n", offset); + os += + fmt::format("elem_to_loc_1(pos.x, in_strides[{0}]);\n", offset); } else if (ndim == 2) { int offset = i * ndim; os += fmt::format( - "elem_to_loc_2({{pos.x, pos.y}}, in_strides + {1});\n", + "elem_to_loc_2<{0}>({{pos.x, pos.y}}, in_strides + {1});\n", idx_type, offset); } else if (ndim == 3) { int offset = i * ndim; os += fmt::format( - "elem_to_loc_3(pos, in_strides + {1});\n", - idx_type, - offset); + "elem_to_loc_3<{0}>(pos, in_strides + {1});\n", idx_type, offset); } else if (!dynamic_dims) { int offset = (i + 1) * ndim; os += fmt::format( @@ -360,10 +358,10 @@ void Compiled::eval_gpu( // Collapse contiguous dims to route to a faster kernel if possible. Also // handle all broadcasting. - std::vector> initial_strides; + std::vector initial_strides; initial_strides.push_back(outputs[0].strides()); - std::vector shape; - std::vector> strides; + Shape shape; + std::vector strides; if (!contiguous) { for (int i = 0; i < inputs.size(); i++) { // Skip constants. @@ -378,7 +376,7 @@ void Compiled::eval_gpu( } // Broadcast the inputs to the output shape. 
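// ---------------------------------------------------------------------------
// Sketch, for illustration: broadcasting an input to the output shape needs no
// copy, because a missing or size-1 axis is given stride 0, mirroring
// Broadcast::eval earlier in this patch
// (strides[i + diff] = (in.shape()[i] == 1) ? 0 : in.strides()[i]).
// Standalone example; all names are local to this sketch.
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Input: shape {3}, contiguous. Output: shape {2, 3}.
  float data[3] = {10.f, 20.f, 30.f};
  std::vector<int32_t> out_shape{2, 3};
  std::vector<int64_t> bcast_strides{0, 1}; // new leading axis -> stride 0

  for (int i = 0; i < out_shape[0]; ++i) {
    for (int j = 0; j < out_shape[1]; ++j) {
      int64_t loc = i * bcast_strides[0] + j * bcast_strides[1];
      std::printf("%g ", data[loc]); // both rows read the same three values
    }
    std::printf("\n");
  }
  return 0;
}
// ---------------------------------------------------------------------------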
- std::vector xstrides; + Strides xstrides; int j = 0; for (; j < output_shape.size() - x.ndim(); j++) { if (output_shape[j] == 1) { @@ -440,7 +438,7 @@ void Compiled::eval_gpu( // Put the inputs in int cnt = 0; int stride_idx = 1; // idx 0 is the output strides - std::vector in_strides; + Strides in_strides; for (int i = 0; i < inputs.size(); i++) { if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) { continue; diff --git a/mlx/backend/metal/conv.cpp b/mlx/backend/metal/conv.cpp index fc1649730..554d280a9 100644 --- a/mlx/backend/metal/conv.cpp +++ b/mlx/backend/metal/conv.cpp @@ -64,8 +64,8 @@ void explicit_gemm_conv_ND_gpu( compute_encoder.dispatch_threads(grid_dims, group_dims); // Reshape weight - std::vector wt_reshape{implicit_K, implicit_N}; - std::vector wt_restride{1, static_cast(implicit_K)}; + Shape wt_reshape{implicit_K, implicit_N}; + Strides wt_restride{1, implicit_K}; array wt_reshaped(wt_reshape, wt.dtype(), nullptr, {}); auto wt_flags = wt.flags(); wt_flags.row_contiguous = false; @@ -147,10 +147,7 @@ void explicit_gemm_conv_group_ND_gpu( array wt_view( {wt.shape(0), C_per_group, kernel_size}, wt.dtype(), nullptr, {}); wt_view.copy_shared_buffer( - wt, - {wt.strides(0), 1, static_cast(C_per_group)}, - wt.flags(), - wt.size()); + wt, {wt.strides(0), 1, C_per_group}, wt.flags(), wt.size()); // Materialize auto wt_transpose = array(wt_view.shape(), wt_view.dtype(), nullptr, {}); diff --git a/mlx/backend/metal/copy.cpp b/mlx/backend/metal/copy.cpp index 60d63a5ac..f808c52ed 100644 --- a/mlx/backend/metal/copy.cpp +++ b/mlx/backend/metal/copy.cpp @@ -43,13 +43,12 @@ void copy_gpu(const array& in, array& out, CopyType ctype) { copy_gpu(in, out, ctype, out.primitive().stream()); } -template void copy_gpu_inplace( const array& in, array& out, - const std::vector& data_shape, - const std::vector& strides_in_pre, - const std::vector& strides_out_pre, + const Shape& data_shape, + const Strides& strides_in_pre, + const Strides& strides_out_pre, int64_t inp_offset, int64_t out_offset, CopyType ctype, @@ -68,8 +67,8 @@ void copy_gpu_inplace( /* size_cap = */ INT32_MAX); return std::make_tuple(shape, strides[0], strides[1]); } else { - std::vector e; - return std::make_tuple(std::vector{}, e, e); + Strides e{}; + return std::make_tuple(Shape{}, e, e); } }; auto [shape, strides_in_, strides_out_] = maybe_collapse(); @@ -124,8 +123,8 @@ void copy_gpu_inplace( auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) { - std::vector strides_in{strides_in_.begin(), strides_in_.end()}; - std::vector strides_out{strides_out_.begin(), strides_out_.end()}; + Strides strides_in{strides_in_.begin(), strides_in_.end()}; + Strides strides_out{strides_out_.begin(), strides_out_.end()}; if (ndim > 3) { compute_encoder.set_vector_bytes(shape, ndim, 2); } @@ -180,14 +179,13 @@ void copy_gpu_inplace( void copy_gpu_inplace( const array& in, array& out, - const std::vector& istride, + const Strides& istride, int64_t ioffset, CopyType ctype, const Stream& s) { assert(in.shape() == out.shape()); - std::vector ostrides{out.strides().begin(), out.strides().end()}; return copy_gpu_inplace( - in, out, in.shape(), istride, ostrides, ioffset, 0, ctype, s); + in, out, in.shape(), istride, out.strides(), ioffset, 0, ctype, s); } void fill_gpu(const array& val, array& out, const Stream& s) { diff --git a/mlx/backend/metal/copy.h b/mlx/backend/metal/copy.h index 3042c714e..2568f9afa 100644 --- a/mlx/backend/metal/copy.h +++ 
b/mlx/backend/metal/copy.h @@ -8,13 +8,12 @@ namespace mlx::core { // Generic copy inplace -template void copy_gpu_inplace( const array& in, array& out, - const std::vector& data_shape, - const std::vector& i_strides, - const std::vector& o_strides, + const Shape& data_shape, + const Strides& i_strides, + const Strides& o_strides, int64_t i_offset, int64_t o_offset, CopyType ctype, @@ -32,7 +31,7 @@ void copy_gpu_inplace( void copy_gpu_inplace( const array& in, array& out, - const std::vector& istride, + const Strides& istride, int64_t ioffset, CopyType ctype, const Stream& s); diff --git a/mlx/backend/metal/fft.cpp b/mlx/backend/metal/fft.cpp index e6da71fe1..83aa18b88 100644 --- a/mlx/backend/metal/fft.cpp +++ b/mlx/backend/metal/fft.cpp @@ -363,7 +363,7 @@ void multi_upload_bluestein_fft( auto [w_k, w_q] = compute_bluestein_constants(n, plan.bluestein_n); // Broadcast w_q and w_k to the batch size - std::vector b_strides(in.ndim(), 0); + Strides b_strides(in.ndim(), 0); b_strides[axis] = 1; array w_k_broadcast({}, complex64, nullptr, {}); array w_q_broadcast({}, complex64, nullptr, {}); @@ -386,8 +386,8 @@ void multi_upload_bluestein_fft( copies.push_back(slice_temp); copies.push_back(conj_temp); - std::vector rstarts(in.ndim(), 0); - std::vector rstrides(in.ndim(), 1); + Shape rstarts(in.ndim(), 0); + Shape rstrides(in.ndim(), 1); rstarts[axis] = in.shape(axis) - back_offset; rstrides[axis] = -1; unary_op_gpu({in}, conj_temp, "Conjugate", s); @@ -431,19 +431,19 @@ void multi_upload_bluestein_fft( s); int offset = plan.bluestein_n - (2 * n - 1); - std::vector starts(in.ndim(), 0); - std::vector strides(in.ndim(), 1); + Shape starts(in.ndim(), 0); + Shape strides(in.ndim(), 1); starts[axis] = plan.bluestein_n - offset - n; slice_gpu(pad_temp1, temp, starts, strides, s); binary_op_gpu_inplace({temp, w_k_broadcast}, temp1, "Multiply", s); if (real && !inverse) { - std::vector rstarts(in.ndim(), 0); - std::vector rstrides(in.ndim(), 1); + Shape rstarts(in.ndim(), 0); + Shape rstrides(in.ndim(), 1); slice_gpu(temp1, out, rstarts, strides, s); } else if (real && inverse) { - std::vector b_strides(in.ndim(), 0); + Strides b_strides(in.ndim(), 0); auto inv_n = array({1.0f / n}, {1}, float32); array temp_float(out.shape(), out.dtype(), nullptr, {}); copies.push_back(temp_float); @@ -531,8 +531,8 @@ void fft_op( return x; } else { array x_copy(x.shape(), x.dtype(), nullptr, {}); - std::vector strides; - size_t cur_stride = x.shape(axis); + Strides strides; + int64_t cur_stride = x.shape(axis); for (int a = 0; a < x.ndim(); a++) { if (a == axis) { strides.push_back(1); @@ -777,7 +777,7 @@ void nd_fft_op( // Mirror np.fft.(i)rfftn and perform a real transform // only on the final axis. bool step_real = (real && index == axes.size() - 1); - int step_shape = inverse ? out.shape(axis) : in.shape(axis); + auto step_shape = inverse ? out.shape(axis) : in.shape(axis); const array& in_arr = i == axes.size() - 1 ? in : temp_arrs[1 - i % 2]; array& out_arr = i == 0 ? out : temp_arrs[i % 2]; fft_op(in_arr, out_arr, axis, inverse, step_real, inplace, s); diff --git a/mlx/backend/metal/indexing.cpp b/mlx/backend/metal/indexing.cpp index bea3e8e57..85d7a711a 100644 --- a/mlx/backend/metal/indexing.cpp +++ b/mlx/backend/metal/indexing.cpp @@ -65,7 +65,7 @@ void Gather::eval_gpu(const std::vector& inputs, array& out) { idx_type_name, nidx, idx_ndim, - large ? "size_t" : "uint"); + large ? 
"int64_t" : "uint"); std::string lib_name = kernel_name; auto lib = d.get_library(lib_name, [&]() { @@ -86,7 +86,7 @@ void Gather::eval_gpu(const std::vector& inputs, array& out) { idx_args, idx_arr, idx_ndim, - large ? "size_t" : "uint"); + large ? "int64_t" : "uint"); return kernel_source; }); @@ -246,7 +246,7 @@ void Scatter::eval_gpu(const std::vector& inputs, array& out) { nidx, upd_contig ? "updc_true" : "updc_false", nwork, - large ? "size_t" : "uint"); + large ? "int64_t" : "uint"); std::string lib_name = kernel_name; auto lib = d.get_library(lib_name, [&]() { @@ -290,7 +290,7 @@ void Scatter::eval_gpu(const std::vector& inputs, array& out) { idx_arr, upd_contig, nwork, - large ? "size_t" : "uint"); + large ? "int64_t" : "uint"); return kernel_source; }); @@ -312,8 +312,8 @@ void Scatter::eval_gpu(const std::vector& inputs, array& out) { upd_size *= upd.shape(i); } // Collect all idx shapes and strides into one place - std::vector idx_shapes; - std::vector idx_strides; + Shape idx_shapes; + Strides idx_strides; // To access .data() use char instead of bool // bool is 1 byte in Metal so this is safe std::vector idx_contigs; @@ -332,7 +332,7 @@ void Scatter::eval_gpu(const std::vector& inputs, array& out) { if (upd_ndim == 0) { // Need placeholders so Metal doesn't compalain int shape_ = 0; - size_t stride_ = 0; + int64_t stride_ = 0; compute_encoder.set_bytes(shape_, 3); compute_encoder.set_bytes(stride_, 4); } else { @@ -347,7 +347,7 @@ void Scatter::eval_gpu(const std::vector& inputs, array& out) { if (out_ndim == 0) { // Need placeholders so Metal doesn't compalain int shape_ = 0; - size_t stride_ = 0; + int64_t stride_ = 0; compute_encoder.set_bytes(shape_, 7); compute_encoder.set_bytes(stride_, 8); } else { diff --git a/mlx/backend/metal/jit/gemv_masked.h b/mlx/backend/metal/jit/gemv_masked.h index deae78865..b83ad881f 100644 --- a/mlx/backend/metal/jit/gemv_masked.h +++ b/mlx/backend/metal/jit/gemv_masked.h @@ -11,13 +11,13 @@ gemv_{trans}masked<{itype}, {outm_t}, {opm_t}, {bm}, {bn}, {sm}, {sn}, {tm}, {tn const constant int& marix_ld [[buffer(6)]], const constant int& batch_ndim [[buffer(9)]], const constant int* batch_shape [[buffer(10)]], - const constant size_t* vector_batch_stride [[buffer(11)]], - const constant size_t* matrix_batch_stride [[buffer(12)]], + const constant int64_t* vector_batch_stride [[buffer(11)]], + const constant int64_t* matrix_batch_stride [[buffer(12)]], const device {outm_t}* out_mask [[buffer(20)]], const device {opm_t}* mat_mask [[buffer(21)]], const device {opm_t}* vec_mask [[buffer(22)]], const constant int* mask_strides [[buffer(23)]], - const constant size_t* mask_batch_strides [[buffer(24)]], + const constant int64_t* mask_batch_strides [[buffer(24)]], uint3 tid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]], diff --git a/mlx/backend/metal/jit/indexing.h b/mlx/backend/metal/jit/indexing.h index eacad0a51..57bafc229 100644 --- a/mlx/backend/metal/jit/indexing.h +++ b/mlx/backend/metal/jit/indexing.h @@ -5,12 +5,12 @@ constexpr std::string_view gather_kernels = R"( const device {1}* src [[buffer(0)]], device {1}* out [[buffer(1)]], const constant int* src_shape [[buffer(2)]], - const constant size_t* src_strides [[buffer(3)]], + const constant int64_t* src_strides [[buffer(3)]], const constant size_t& src_ndim [[buffer(4)]], const constant int* slice_sizes [[buffer(5)]], const constant int* axes [[buffer(6)]], const constant int* idx_shapes [[buffer(7)]], - 
const constant size_t* idx_strides [[buffer(8)]], + const constant int64_t* idx_strides [[buffer(8)]], const constant bool* idx_contigs [[buffer(9)]], const constant int& idx_ndim [[buffer(10)]], {4} @@ -38,15 +38,15 @@ constexpr std::string_view scatter_kernels = R"( const device {1}* updates [[buffer(1)]], device mlx_atomic<{1}>* out [[buffer(2)]], const constant int* upd_shape [[buffer(3)]], - const constant size_t* upd_strides [[buffer(4)]], + const constant int64_t* upd_strides [[buffer(4)]], const constant size_t& upd_ndim [[buffer(5)]], const constant size_t& upd_size [[buffer(6)]], const constant int* out_shape [[buffer(7)]], - const constant size_t* out_strides [[buffer(8)]], + const constant int64_t* out_strides [[buffer(8)]], const constant size_t& out_ndim [[buffer(9)]], const constant int* axes [[buffer(10)]], const constant int* idx_shapes [[buffer(11)]], - const constant size_t* idx_strides [[buffer(12)]], + const constant int64_t* idx_strides [[buffer(12)]], const constant bool* idx_contigs [[buffer(13)]], const constant int& idx_ndim [[buffer(14)]], const constant size_t& idx_size [[buffer(15)]], diff --git a/mlx/backend/metal/jit/steel_gemm.h b/mlx/backend/metal/jit/steel_gemm.h index d1a2378bf..85ddc449a 100644 --- a/mlx/backend/metal/jit/steel_gemm.h +++ b/mlx/backend/metal/jit/steel_gemm.h @@ -10,12 +10,12 @@ template [[host_name("{name}")]] const constant GEMMParams* params [[buffer(4)]], const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]], const constant int* batch_shape [[buffer(6)]], - const constant size_t* batch_strides [[buffer(7)]], + const constant int64_t* batch_strides [[buffer(7)]], const constant uint32_t* lhs_indices [[buffer(10), function_constant(do_gather)]], const constant uint32_t* rhs_indices [[buffer(11), function_constant(do_gather)]], const constant uint32_t* C_indices [[buffer(12), function_constant(gather_bias)]], const constant int* operand_shape [[buffer(13), function_constant(do_gather)]], - const constant size_t* operand_strides [[buffer(14), function_constant(do_gather)]], + const constant int64_t* operand_strides [[buffer(14), function_constant(do_gather)]], const constant packed_int3& operand_batch_ndim [[buffer(15), function_constant(do_gather)]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]], @@ -43,7 +43,7 @@ block_masked_gemm< device {itype}* D [[buffer(3)]], const constant GEMMParams* params [[buffer(4)]], const constant int* batch_shape [[buffer(6)]], - const constant size_t* batch_strides [[buffer(7)]], + const constant int64_t* batch_strides [[buffer(7)]], const device {outmasktype}* out_mask [[buffer(10)]], const device {opmasktype}* lhs_mask [[buffer(11)]], const device {opmasktype}* rhs_mask [[buffer(12)]], diff --git a/mlx/backend/metal/kernels/arg_reduce.metal b/mlx/backend/metal/kernels/arg_reduce.metal index fa32dec4f..7f1075ad9 100644 --- a/mlx/backend/metal/kernels/arg_reduce.metal +++ b/mlx/backend/metal/kernels/arg_reduce.metal @@ -75,10 +75,10 @@ template const device T* in [[buffer(0)]], device uint32_t* out [[buffer(1)]], const constant int* shape [[buffer(2)]], - const constant size_t* in_strides [[buffer(3)]], - const constant size_t* out_strides [[buffer(4)]], + const constant int64_t* in_strides [[buffer(3)]], + const constant int64_t* out_strides [[buffer(4)]], const constant size_t& ndim [[buffer(5)]], - const constant size_t& axis_stride [[buffer(6)]], + const constant int64_t& axis_stride [[buffer(6)]], const 
constant size_t& axis_size [[buffer(7)]], uint gid [[thread_position_in_grid]], uint lid [[thread_position_in_threadgroup]], diff --git a/mlx/backend/metal/kernels/binary.h b/mlx/backend/metal/kernels/binary.h index d9f0a9710..91a02c818 100644 --- a/mlx/backend/metal/kernels/binary.h +++ b/mlx/backend/metal/kernels/binary.h @@ -43,7 +43,7 @@ template device U* c, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - size_t offset = index.x + grid_dim.x * size_t(index.y); + int64_t offset = index.x + grid_dim.x * int64_t(index.y); c[offset] = Op()(a[0], b[offset]); } @@ -54,7 +54,7 @@ template device U* c, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - size_t offset = index.x + grid_dim.x * size_t(index.y); + int64_t offset = index.x + grid_dim.x * int64_t(index.y); c[offset] = Op()(a[offset], b[0]); } @@ -65,49 +65,49 @@ template device U* c, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - size_t offset = index.x + grid_dim.x * size_t(index.y); + int64_t offset = index.x + grid_dim.x * int64_t(index.y); c[offset] = Op()(a[offset], b[offset]); } -template +template [[kernel]] void binary_g_nd1( device const T* a, device const T* b, device U* c, - constant const size_t& a_stride, - constant const size_t& b_stride, + constant const int64_t& a_stride, + constant const int64_t& b_stride, uint index [[thread_position_in_grid]]) { - auto a_idx = elem_to_loc_1(index, a_stride); - auto b_idx = elem_to_loc_1(index, b_stride); + auto a_idx = elem_to_loc_1(index, a_stride); + auto b_idx = elem_to_loc_1(index, b_stride); c[index] = Op()(a[a_idx], b[b_idx]); } -template +template [[kernel]] void binary_g_nd2( device const T* a, device const T* b, device U* c, - constant const size_t a_strides[2], - constant const size_t b_strides[2], + constant const int64_t a_strides[2], + constant const int64_t b_strides[2], uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto a_idx = elem_to_loc_2(index, a_strides); - auto b_idx = elem_to_loc_2(index, b_strides); + auto a_idx = elem_to_loc_2(index, a_strides); + auto b_idx = elem_to_loc_2(index, b_strides); IdxT out_idx = index.x + IdxT(grid_dim.x) * index.y; c[out_idx] = Op()(a[a_idx], b[b_idx]); } -template +template [[kernel]] void binary_g_nd3( device const T* a, device const T* b, device U* c, - constant const size_t a_strides[3], - constant const size_t b_strides[3], + constant const int64_t a_strides[3], + constant const int64_t b_strides[3], uint3 index [[thread_position_in_grid]], uint3 grid_dim [[threads_per_grid]]) { - auto a_idx = elem_to_loc_3(index, a_strides); - auto b_idx = elem_to_loc_3(index, b_strides); + auto a_idx = elem_to_loc_3(index, a_strides); + auto b_idx = elem_to_loc_3(index, b_strides); IdxT out_idx = index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z); c[out_idx] = Op()(a[a_idx], b[b_idx]); } @@ -117,18 +117,18 @@ template < typename U, typename Op, int N = 1, - typename IdxT = size_t> + typename IdxT = int64_t> [[kernel]] void binary_g( device const T* a, device const T* b, device U* c, constant const int* shape, - constant const size_t* a_strides, - constant const size_t* b_strides, + constant const int64_t* a_strides, + constant const int64_t* b_strides, constant const int& ndim, uint3 index [[thread_position_in_grid]], uint3 grid_dim [[threads_per_grid]]) { - auto idx = elem_to_loc_2_nd( + auto idx = elem_to_loc_2_nd( {N * index.x, index.y, index.z}, shape, a_strides, b_strides, ndim); auto 
xshape = shape[ndim - 1]; IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z); diff --git a/mlx/backend/metal/kernels/binary_two.h b/mlx/backend/metal/kernels/binary_two.h index 17dfb0f62..8f6b3392d 100644 --- a/mlx/backend/metal/kernels/binary_two.h +++ b/mlx/backend/metal/kernels/binary_two.h @@ -56,7 +56,7 @@ template device U* d, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - size_t offset = index.x + grid_dim.x * size_t(index.y); + auto offset = index.x + grid_dim.x * int64_t(index.y); auto out = Op()(a[0], b[offset]); c[offset] = out[0]; d[offset] = out[1]; @@ -70,7 +70,7 @@ template device U* d, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - size_t offset = index.x + grid_dim.x * size_t(index.y); + auto offset = index.x + grid_dim.x * int64_t(index.y); auto out = Op()(a[offset], b[0]); c[offset] = out[0]; d[offset] = out[1]; @@ -84,58 +84,58 @@ template device U* d, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - size_t offset = index.x + grid_dim.x * size_t(index.y); + auto offset = index.x + grid_dim.x * int64_t(index.y); auto out = Op()(a[offset], b[offset]); c[offset] = out[0]; d[offset] = out[1]; } -template +template [[kernel]] void binary_g_nd1( device const T* a, device const T* b, device U* c, device U* d, - constant const size_t& a_stride, - constant const size_t& b_stride, + constant const int64_t& a_stride, + constant const int64_t& b_stride, uint index [[thread_position_in_grid]]) { - auto a_idx = elem_to_loc_1(index, a_stride); - auto b_idx = elem_to_loc_1(index, b_stride); + auto a_idx = elem_to_loc_1(index, a_stride); + auto b_idx = elem_to_loc_1(index, b_stride); auto out = Op()(a[a_idx], b[b_idx]); c[index] = out[0]; d[index] = out[1]; } -template +template [[kernel]] void binary_g_nd2( device const T* a, device const T* b, device U* c, device U* d, - constant const size_t a_strides[2], - constant const size_t b_strides[2], + constant const int64_t a_strides[2], + constant const int64_t b_strides[2], uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto a_idx = elem_to_loc_2(index, a_strides); - auto b_idx = elem_to_loc_2(index, b_strides); + auto a_idx = elem_to_loc_2(index, a_strides); + auto b_idx = elem_to_loc_2(index, b_strides); IdxT out_idx = index.x + IdxT(grid_dim.x) * index.y; auto out = Op()(a[a_idx], b[b_idx]); c[out_idx] = out[0]; d[out_idx] = out[1]; } -template +template [[kernel]] void binary_g_nd3( device const T* a, device const T* b, device U* c, device U* d, - constant const size_t a_strides[3], - constant const size_t b_strides[3], + constant const int64_t a_strides[3], + constant const int64_t b_strides[3], uint3 index [[thread_position_in_grid]], uint3 grid_dim [[threads_per_grid]]) { - auto a_idx = elem_to_loc_3(index, a_strides); - auto b_idx = elem_to_loc_3(index, b_strides); + auto a_idx = elem_to_loc_3(index, a_strides); + auto b_idx = elem_to_loc_3(index, b_strides); IdxT out_idx = index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z); auto out = Op()(a[a_idx], b[b_idx]); c[out_idx] = out[0]; @@ -147,19 +147,19 @@ template < typename U, typename Op, int N = 1, - typename IdxT = size_t> + typename IdxT = int64_t> [[kernel]] void binary_g( device const T* a, device const T* b, device U* c, device U* d, constant const int* shape, - constant const size_t* a_strides, - constant const size_t* b_strides, + constant const int64_t* a_strides, + constant const int64_t* b_strides, 
constant const int& ndim, uint3 index [[thread_position_in_grid]], uint3 grid_dim [[threads_per_grid]]) { - auto idx = elem_to_loc_2_nd( + auto idx = elem_to_loc_2_nd( {N * index.x, index.y, index.z}, shape, a_strides, b_strides, ndim); auto xshape = shape[ndim - 1]; IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z); diff --git a/mlx/backend/metal/kernels/copy.h b/mlx/backend/metal/kernels/copy.h index 7b664b3a4..dddcda366 100644 --- a/mlx/backend/metal/kernels/copy.h +++ b/mlx/backend/metal/kernels/copy.h @@ -22,7 +22,7 @@ template device U* dst [[buffer(1)]], uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - size_t offset = index.x + grid_dim.x * size_t(index.y); + auto offset = index.x + grid_dim.x * int64_t(index.y); dst[offset] = static_cast(src[0]); } @@ -32,7 +32,7 @@ template device U* dst [[buffer(1)]], uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - size_t offset = index.x + grid_dim.x * size_t(index.y); + auto offset = index.x + grid_dim.x * int64_t(index.y); dst[offset] = static_cast(src[offset]); } @@ -42,7 +42,7 @@ template device U* dst [[buffer(1)]], constant const int64_t& src_stride [[buffer(3)]], uint index [[thread_position_in_grid]]) { - auto src_idx = elem_to_loc_1(index, src_stride); + auto src_idx = elem_to_loc_1(index, src_stride); dst[index] = static_cast(src[src_idx]); } @@ -53,7 +53,7 @@ template constant const int64_t* src_strides [[buffer(3)]], uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto src_idx = elem_to_loc_2(index, src_strides); + auto src_idx = elem_to_loc_2(index, src_strides); IdxT dst_idx = index.x + IdxT(grid_dim.x) * index.y; dst[dst_idx] = static_cast(src[src_idx]); } @@ -65,7 +65,7 @@ template constant const int64_t* src_strides [[buffer(3)]], uint3 index [[thread_position_in_grid]], uint3 grid_dim [[threads_per_grid]]) { - auto src_idx = elem_to_loc_3(index, src_strides); + auto src_idx = elem_to_loc_3(index, src_strides); IdxT dst_idx = index.x + IdxT(grid_dim.x) * (index.y + IdxT(grid_dim.y) * index.z); dst[dst_idx] = static_cast(src[src_idx]); @@ -80,7 +80,7 @@ template constant const int& ndim [[buffer(5)]], uint3 index [[thread_position_in_grid]], uint3 grid_dim [[threads_per_grid]]) { - auto src_idx = elem_to_loc( + auto src_idx = elem_to_loc( {N * index.x, index.y, index.z}, src_shape, src_strides, ndim); if (N == 1) { IdxT dst_idx = @@ -104,8 +104,8 @@ template constant const int64_t& src_stride [[buffer(3)]], constant const int64_t& dst_stride [[buffer(4)]], uint index [[thread_position_in_grid]]) { - auto src_idx = elem_to_loc_1(index, src_stride); - auto dst_idx = elem_to_loc_1(index, dst_stride); + auto src_idx = elem_to_loc_1(index, src_stride); + auto dst_idx = elem_to_loc_1(index, dst_stride); dst[dst_idx] = static_cast(src[src_idx]); } @@ -116,8 +116,8 @@ template constant const int64_t* src_strides [[buffer(3)]], constant const int64_t* dst_strides [[buffer(4)]], uint2 index [[thread_position_in_grid]]) { - auto src_idx = elem_to_loc_2(index, src_strides); - auto dst_idx = elem_to_loc_2(index, dst_strides); + auto src_idx = elem_to_loc_2(index, src_strides); + auto dst_idx = elem_to_loc_2(index, dst_strides); dst[dst_idx] = static_cast(src[src_idx]); } @@ -128,8 +128,8 @@ template constant const int64_t* src_strides [[buffer(3)]], constant const int64_t* dst_strides [[buffer(4)]], uint3 index [[thread_position_in_grid]]) { - auto src_idx = elem_to_loc_3(index, src_strides); - auto dst_idx = 
elem_to_loc_3(index, dst_strides); + auto src_idx = elem_to_loc_3(index, src_strides); + auto dst_idx = elem_to_loc_3(index, dst_strides); dst[dst_idx] = static_cast(src[src_idx]); } @@ -142,7 +142,7 @@ template constant const int64_t* dst_strides [[buffer(4)]], constant const int& ndim [[buffer(5)]], uint3 index [[thread_position_in_grid]]) { - auto idx = elem_to_loc_2_nd( + auto idx = elem_to_loc_2_nd( {N * index.x, index.y, index.z}, src_shape, src_strides, diff --git a/mlx/backend/metal/kernels/gather.h b/mlx/backend/metal/kernels/gather.h index b38ab6283..472e497c0 100644 --- a/mlx/backend/metal/kernels/gather.h +++ b/mlx/backend/metal/kernels/gather.h @@ -9,7 +9,7 @@ METAL_FUNC void gather_impl( const device T* src [[buffer(0)]], device T* out [[buffer(1)]], const constant int* src_shape [[buffer(2)]], - const constant size_t* src_strides [[buffer(3)]], + const constant int64_t* src_strides [[buffer(3)]], const constant size_t& src_ndim [[buffer(4)]], const constant int* slice_sizes [[buffer(5)]], const constant int* axes [[buffer(6)]], @@ -27,7 +27,7 @@ METAL_FUNC void gather_impl( idx_loc = index.x * static_cast(indices.strides[indices.ndim * i]); idx_loc += indices.row_contiguous[i] ? index.y - : elem_to_loc( + : elem_to_loc( index.y, &indices.shapes[indices.ndim * i + 1], &indices.strides[indices.ndim * i + 1], @@ -39,7 +39,7 @@ METAL_FUNC void gather_impl( } auto src_offset = - elem_to_loc(index.z, slice_sizes, src_strides, src_ndim); + elem_to_loc(index.z, slice_sizes, src_strides, src_ndim); LocT out_idx = index.z; if (IDX_NDIM == 1) { diff --git a/mlx/backend/metal/kernels/gemv.metal b/mlx/backend/metal/kernels/gemv.metal index 1776c54e2..28cadd50a 100644 --- a/mlx/backend/metal/kernels/gemv.metal +++ b/mlx/backend/metal/kernels/gemv.metal @@ -436,9 +436,9 @@ template < const constant float& beta [[buffer(8)]], const constant int& batch_ndim [[buffer(9)]], const constant int* batch_shape [[buffer(10)]], - const constant size_t* vector_batch_stride [[buffer(11)]], - const constant size_t* matrix_batch_stride [[buffer(12)]], - const constant size_t* bias_batch_stride [[buffer(13)]], + const constant int64_t* vector_batch_stride [[buffer(11)]], + const constant int64_t* matrix_batch_stride [[buffer(12)]], + const constant int64_t* bias_batch_stride [[buffer(13)]], const constant int& bias_stride [[buffer(14)]], uint3 tid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]], @@ -486,31 +486,21 @@ template < simd_lid); } -#define instantiate_gemv_helper( \ - name, itype, bm, bn, sm, sn, tm, tn, nc, axpby) \ - template [[host_name("gemv_" #name "_bm" #bm "_bn" #bn "_sm" #sm "_sn" #sn \ - "_tm" #tm "_tn" #tn "_nc" #nc \ - "_axpby" #axpby)]] [[kernel]] void \ - gemv( \ - const device itype* mat [[buffer(0)]], \ - const device itype* in_vec [[buffer(1)]], \ - const device itype* bias [[buffer(2)]], \ - device itype* out_vec [[buffer(3)]], \ - const constant int& in_vec_size [[buffer(4)]], \ - const constant int& out_vec_size [[buffer(5)]], \ - const constant int& marix_ld [[buffer(6)]], \ - const constant float& alpha [[buffer(7)]], \ - const constant float& beta [[buffer(8)]], \ - const constant int& batch_ndim [[buffer(9)]], \ - const constant int* batch_shape [[buffer(10)]], \ - const constant size_t* vector_batch_stride [[buffer(11)]], \ - const constant size_t* matrix_batch_stride [[buffer(12)]], \ - const constant size_t* bias_batch_stride [[buffer(13)]], \ - const constant int& bias_stride [[buffer(14)]], \ - uint3 tid [[threadgroup_position_in_grid]], 
\ - uint3 lid [[thread_position_in_threadgroup]], \ - uint simd_gid [[simdgroup_index_in_threadgroup]], \ - uint simd_lid [[thread_index_in_simdgroup]]); +#define instantiate_gemv_helper( \ + name, itype, bm, bn, sm, sn, tm, tn, nc, axpby) \ + instantiate_kernel( \ + "gemv_" #name "_bm" #bm "_bn" #bn "_sm" #sm "_sn" #sn "_tm" #tm \ + "_tn" #tn "_nc" #nc "_axpby" #axpby, \ + gemv, \ + itype, \ + bm, \ + bn, \ + sm, \ + sn, \ + tm, \ + tn, \ + nc, \ + axpby) // clang-format off #define instantiate_gemv(name, itype, bm, bn, tm, tn) \ @@ -549,13 +539,13 @@ template < const constant float& beta [[buffer(8)]], const constant int& batch_ndim [[buffer(9)]], const constant int* batch_shape [[buffer(10)]], - const constant size_t* index_batch_strides [[buffer(11)]], + const constant int64_t* index_batch_strides [[buffer(11)]], const constant int& vector_batch_ndim [[buffer(12)]], const constant int* vector_batch_shape [[buffer(13)]], - const constant size_t* vector_batch_stride [[buffer(14)]], + const constant int64_t* vector_batch_stride [[buffer(14)]], const constant int& matrix_batch_ndim [[buffer(15)]], const constant int* matrix_batch_shape [[buffer(16)]], - const constant size_t* matrix_batch_stride [[buffer(17)]], + const constant int64_t* matrix_batch_stride [[buffer(17)]], const constant uint32_t* vec_indices [[buffer(18)]], const constant uint32_t* mat_indices [[buffer(19)]], uint3 tid [[threadgroup_position_in_grid]], @@ -571,8 +561,8 @@ template < // Update batch offsets if (batch_ndim > 1) { - const constant size_t* veci_bstrides = index_batch_strides; - const constant size_t* mati_bstrides = index_batch_strides + batch_ndim; + const constant auto* veci_bstrides = index_batch_strides; + const constant auto* mati_bstrides = index_batch_strides + batch_ndim; ulong2 batch_offsets = elem_to_loc_broadcast( tid.z, batch_shape, veci_bstrides, mati_bstrides, batch_ndim); @@ -619,37 +609,14 @@ template < simd_lid); } -#define instantiate_gemv_bs_helper(nm, itype, bm, bn, sm, sn, tm, tn) \ - template [[host_name("gemv_gather_" #nm "_bm" #bm "_bn" #bn "_sm" #sm \ - "_sn" #sn "_tm" #tm "_tn" #tn)]] [[kernel]] void \ - gemv_gather( \ - const device itype* mat [[buffer(0)]], \ - const device itype* in_vec [[buffer(1)]], \ - const device itype* bias [[buffer(2)]], \ - device itype* out_vec [[buffer(3)]], \ - const constant int& in_vec_size [[buffer(4)]], \ - const constant int& out_vec_size [[buffer(5)]], \ - const constant int& marix_ld [[buffer(6)]], \ - const constant float& alpha [[buffer(7)]], \ - const constant float& beta [[buffer(8)]], \ - const constant int& batch_ndim [[buffer(9)]], \ - const constant int* batch_shape [[buffer(10)]], \ - const constant size_t* index_batch_strides [[buffer(11)]], \ - const constant int& vector_batch_ndim [[buffer(12)]], \ - const constant int* vector_batch_shape [[buffer(13)]], \ - const constant size_t* vector_batch_stride [[buffer(14)]], \ - const constant int& matrix_batch_ndim [[buffer(15)]], \ - const constant int* matrix_batch_shape [[buffer(16)]], \ - const constant size_t* matrix_batch_stride [[buffer(17)]], \ - const constant uint32_t* vec_indices [[buffer(18)]], \ - const constant uint32_t* mat_indices [[buffer(19)]], \ - uint3 tid [[threadgroup_position_in_grid]], \ - uint3 lid [[thread_position_in_threadgroup]], \ - uint simd_gid [[simdgroup_index_in_threadgroup]], \ - uint simd_lid [[thread_index_in_simdgroup]]); - // clang-format off -#define instantiate_gemv_bs_blocks(name, itype) \ +#define instantiate_gemv_bs_helper(nm, itype, bm, bn, sm, sn, 
tm, tn) \ + instantiate_kernel( \ + "gemv_gather_" #nm "_bm" #bm "_bn" #bn "_sm" #sm \ + "_sn" #sn "_tm" #tm "_tn" #tn, \ + gemv_gather, itype, bm, bn, sm, sn, tm, tn) + +#define instantiate_gemv_bs_blocks(name, itype) \ instantiate_gemv_bs_helper(name, itype, 4, 1, 1, 32, 1, 4) \ instantiate_gemv_bs_helper(name, itype, 4, 1, 1, 32, 4, 4) \ instantiate_gemv_bs_helper(name, itype, 8, 1, 1, 32, 4, 4) // clang-format on @@ -684,9 +651,9 @@ template < const constant float& beta [[buffer(8)]], const constant int& batch_ndim [[buffer(9)]], const constant int* batch_shape [[buffer(10)]], - const constant size_t* vector_batch_stride [[buffer(11)]], - const constant size_t* matrix_batch_stride [[buffer(12)]], - const constant size_t* bias_batch_stride [[buffer(13)]], + const constant int64_t* vector_batch_stride [[buffer(11)]], + const constant int64_t* matrix_batch_stride [[buffer(12)]], + const constant int64_t* bias_batch_stride [[buffer(13)]], const constant int& bias_stride [[buffer(14)]], uint3 tid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]], @@ -734,33 +701,14 @@ template < simd_lid); } -#define instantiate_gemv_t_helper( \ - name, itype, bm, bn, sm, sn, tm, tn, nc, axpby) \ - template [[host_name("gemv_t_" #name "_bm" #bm "_bn" #bn "_sm" #sm "_sn" #sn \ - "_tm" #tm "_tn" #tn "_nc" #nc \ - "_axpby" #axpby)]] [[kernel]] void \ - gemv_t( \ - const device itype* mat [[buffer(0)]], \ - const device itype* in_vec [[buffer(1)]], \ - const device itype* bias [[buffer(2)]], \ - device itype* out_vec [[buffer(3)]], \ - const constant int& in_vec_size [[buffer(4)]], \ - const constant int& out_vec_size [[buffer(5)]], \ - const constant int& marix_ld [[buffer(6)]], \ - const constant float& alpha [[buffer(7)]], \ - const constant float& beta [[buffer(8)]], \ - const constant int& batch_ndim [[buffer(9)]], \ - const constant int* batch_shape [[buffer(10)]], \ - const constant size_t* vector_batch_stride [[buffer(11)]], \ - const constant size_t* matrix_batch_stride [[buffer(12)]], \ - const constant size_t* bias_batch_stride [[buffer(13)]], \ - const constant int& bias_stride [[buffer(14)]], \ - uint3 tid [[threadgroup_position_in_grid]], \ - uint3 lid [[thread_position_in_threadgroup]], \ - uint simd_gid [[simdgroup_index_in_threadgroup]], \ - uint simd_lid [[thread_index_in_simdgroup]]); - // clang-format off +#define instantiate_gemv_t_helper( \ + name, itype, bm, bn, sm, sn, tm, tn, nc, axpby) \ + instantiate_kernel( \ + "gemv_t_" #name "_bm" #bm "_bn" #bn "_sm" #sm "_sn" #sn \ + "_tm" #tm "_tn" #tn "_nc" #nc "_axpby" #axpby, \ + gemv_t, itype, bm, bn, sm, sn, tm, tn, nc, axpby) + #define instantiate_gemv_t(name, itype, bm, bn, sm, sn, tm, tn) \ instantiate_gemv_t_helper(name, itype, bm, bn, sm, sn, tm, tn, 0, 0) \ instantiate_gemv_t_helper(name, itype, bm, bn, sm, sn, tm, tn, 0, 1) \ @@ -800,13 +748,13 @@ template < const constant float& beta [[buffer(8)]], const constant int& batch_ndim [[buffer(9)]], const constant int* batch_shape [[buffer(10)]], - const constant size_t* index_batch_strides [[buffer(11)]], + const constant int64_t* index_batch_strides [[buffer(11)]], const constant int& vector_batch_ndim [[buffer(12)]], const constant int* vector_batch_shape [[buffer(13)]], - const constant size_t* vector_batch_stride [[buffer(14)]], + const constant int64_t* vector_batch_stride [[buffer(14)]], const constant int& matrix_batch_ndim [[buffer(15)]], const constant int* matrix_batch_shape [[buffer(16)]], - const constant size_t* matrix_batch_stride [[buffer(17)]], 
+ const constant int64_t* matrix_batch_stride [[buffer(17)]], const constant uint32_t* vec_indices [[buffer(18)]], const constant uint32_t* mat_indices [[buffer(19)]], uint3 tid [[threadgroup_position_in_grid]], @@ -822,8 +770,8 @@ template < // Update batch offsets if (batch_ndim > 1) { - const constant size_t* veci_bstrides = index_batch_strides; - const constant size_t* mati_bstrides = index_batch_strides + batch_ndim; + const constant auto* veci_bstrides = index_batch_strides; + const constant auto* mati_bstrides = index_batch_strides + batch_ndim; ulong2 batch_offsets = elem_to_loc_broadcast( tid.z, batch_shape, veci_bstrides, mati_bstrides, batch_ndim); @@ -870,36 +818,14 @@ template < simd_lid); } -#define instantiate_gemv_t_bs_helper(nm, itype, bm, bn, sm, sn, tm, tn) \ - template [[host_name("gemv_t_gather_" #nm "_bm" #bm "_bn" #bn "_sm" #sm \ - "_sn" #sn "_tm" #tm "_tn" #tn)]] [[kernel]] void \ - gemv_t_gather( \ - const device itype* mat [[buffer(0)]], \ - const device itype* in_vec [[buffer(1)]], \ - const device itype* bias [[buffer(2)]], \ - device itype* out_vec [[buffer(3)]], \ - const constant int& in_vec_size [[buffer(4)]], \ - const constant int& out_vec_size [[buffer(5)]], \ - const constant int& marix_ld [[buffer(6)]], \ - const constant float& alpha [[buffer(7)]], \ - const constant float& beta [[buffer(8)]], \ - const constant int& batch_ndim [[buffer(9)]], \ - const constant int* batch_shape [[buffer(10)]], \ - const constant size_t* index_batch_strides [[buffer(11)]], \ - const constant int& vector_batch_ndim [[buffer(12)]], \ - const constant int* vector_batch_shape [[buffer(13)]], \ - const constant size_t* vector_batch_stride [[buffer(14)]], \ - const constant int& matrix_batch_ndim [[buffer(15)]], \ - const constant int* matrix_batch_shape [[buffer(16)]], \ - const constant size_t* matrix_batch_stride [[buffer(17)]], \ - const constant uint32_t* vec_indices [[buffer(18)]], \ - const constant uint32_t* mat_indices [[buffer(19)]], \ - uint3 tid [[threadgroup_position_in_grid]], \ - uint3 lid [[thread_position_in_threadgroup]], \ - uint simd_gid [[simdgroup_index_in_threadgroup]], \ - uint simd_lid [[thread_index_in_simdgroup]]); - // clang-format off +#define instantiate_gemv_t_bs_helper( \ + nm, itype, bm, bn, sm, sn, tm, tn) \ + instantiate_kernel( \ + "gemv_t_gather_" #nm "_bm" #bm "_bn" #bn "_sm" #sm \ + "_sn" #sn "_tm" #tm "_tn" #tn, \ + gemv_t_gather, itype, bm, bn, sm, sn, tm, tn) + #define instantiate_gemv_t_bs_blocks(name, itype) \ instantiate_gemv_t_bs_helper(name, itype, 1, 2, 8, 4, 4, 1) \ instantiate_gemv_t_bs_helper(name, itype, 1, 2, 8, 4, 4, 4) \ diff --git a/mlx/backend/metal/kernels/gemv_masked.h b/mlx/backend/metal/kernels/gemv_masked.h index 1dd436e32..48acf1d61 100644 --- a/mlx/backend/metal/kernels/gemv_masked.h +++ b/mlx/backend/metal/kernels/gemv_masked.h @@ -642,13 +642,13 @@ template < const constant int& marix_ld [[buffer(6)]], const constant int& batch_ndim [[buffer(9)]], const constant int* batch_shape [[buffer(10)]], - const constant size_t* vector_batch_stride [[buffer(11)]], - const constant size_t* matrix_batch_stride [[buffer(12)]], + const constant int64_t* vector_batch_stride [[buffer(11)]], + const constant int64_t* matrix_batch_stride [[buffer(12)]], const device out_mask_t* out_mask [[buffer(20)]], const device op_mask_t* mat_mask [[buffer(21)]], const device op_mask_t* vec_mask [[buffer(22)]], const constant int* mask_strides [[buffer(23)]], - const constant size_t* mask_batch_strides [[buffer(24)]], + const constant 
int64_t* mask_batch_strides [[buffer(24)]], uint3 tid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]], @@ -673,8 +673,8 @@ template < } if (has_operand_mask) { - const constant size_t* mask_strides_mat = mask_batch_strides; - const constant size_t* mask_strides_vec = mask_strides_mat + batch_ndim; + const constant auto* mask_strides_mat = mask_batch_strides; + const constant auto* mask_strides_vec = mask_strides_mat + batch_ndim; ulong2 batch_offsets = elem_to_loc_broadcast( tid.z, batch_shape, mask_strides_mat, mask_strides_vec, batch_ndim); @@ -742,13 +742,13 @@ template < const constant int& marix_ld [[buffer(6)]], const constant int& batch_ndim [[buffer(9)]], const constant int* batch_shape [[buffer(10)]], - const constant size_t* vector_batch_stride [[buffer(11)]], - const constant size_t* matrix_batch_stride [[buffer(12)]], + const constant int64_t* vector_batch_stride [[buffer(11)]], + const constant int64_t* matrix_batch_stride [[buffer(12)]], const device out_mask_t* out_mask [[buffer(20)]], const device op_mask_t* mat_mask [[buffer(21)]], const device op_mask_t* vec_mask [[buffer(22)]], const constant int* mask_strides [[buffer(23)]], - const constant size_t* mask_batch_strides [[buffer(24)]], + const constant int64_t* mask_batch_strides [[buffer(24)]], uint3 tid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]], @@ -773,8 +773,8 @@ template < } if (has_operand_mask) { - const constant size_t* mask_strides_mat = mask_batch_strides; - const constant size_t* mask_strides_vec = mask_strides_mat + batch_ndim; + const constant auto* mask_strides_mat = mask_batch_strides; + const constant auto* mask_strides_vec = mask_strides_mat + batch_ndim; ulong2 batch_offsets = elem_to_loc_broadcast( tid.z, batch_shape, mask_strides_mat, mask_strides_vec, batch_ndim); diff --git a/mlx/backend/metal/kernels/gemv_masked.metal b/mlx/backend/metal/kernels/gemv_masked.metal index db787e7fc..394250e29 100644 --- a/mlx/backend/metal/kernels/gemv_masked.metal +++ b/mlx/backend/metal/kernels/gemv_masked.metal @@ -10,29 +10,11 @@ #define instantiate_gemv_helper( \ outm_n, outm_t, opm_n, opm_t, name, itype, bm, bn, sm, sn, tm, tn, nc) \ - template [[host_name("gemv_outmask_" #outm_n "_opmask_" #opm_n "_" #name \ - "_bm" #bm "_bn" #bn "_sm" #sm "_sn" #sn "_tm" #tm \ - "_tn" #tn "_nc" #nc)]] [[kernel]] void \ - gemv_masked( \ - const device itype* mat [[buffer(0)]], \ - const device itype* in_vec [[buffer(1)]], \ - device itype* out_vec [[buffer(3)]], \ - const constant int& in_vec_size [[buffer(4)]], \ - const constant int& out_vec_size [[buffer(5)]], \ - const constant int& marix_ld [[buffer(6)]], \ - const constant int& batch_ndim [[buffer(9)]], \ - const constant int* batch_shape [[buffer(10)]], \ - const constant size_t* vector_batch_stride [[buffer(11)]], \ - const constant size_t* matrix_batch_stride [[buffer(12)]], \ - const device outm_t* out_mask [[buffer(20)]], \ - const device opm_t* mat_mask [[buffer(21)]], \ - const device opm_t* vec_mask [[buffer(22)]], \ - const constant int* mask_strides [[buffer(23)]], \ - const constant size_t* mask_batch_strides [[buffer(24)]], \ - uint3 tid [[threadgroup_position_in_grid]], \ - uint3 lid [[thread_position_in_threadgroup]], \ - uint simd_gid [[simdgroup_index_in_threadgroup]], \ - uint simd_lid [[thread_index_in_simdgroup]]); + instantiate_kernel( \ + "gemv_outmask_" #outm_n "_opmask_" #opm_n "_" 
#name \ + "_bm" #bm "_bn" #bn "_sm" #sm "_sn" #sn "_tm" #tm \ + "_tn" #tn "_nc" #nc, \ + gemv_masked, itype, outm_t, opm_t, bm, bn, sm, sn, tm, tn, nc) #define instantiate_gemv_base(name, itype, bm, bn, sm, sn, tm, tn, nc) \ instantiate_gemv_helper(bool_, bool, bool_, bool, name, itype, bm, bn, sm, sn, tm, tn, nc) \ @@ -61,29 +43,11 @@ instantiate_gemv_blocks(bfloat16, bfloat16_t); #define instantiate_gemv_t_helper( \ outm_n, outm_t, opm_n, opm_t, name, itype, bm, bn, sm, sn, tm, tn, nc) \ - template [[host_name("gemv_t_outmask_" #outm_n "_opmask_" #opm_n "_" #name \ - "_bm" #bm "_bn" #bn "_sm" #sm "_sn" #sn "_tm" #tm \ - "_tn" #tn "_nc" #nc)]] [[kernel]] void \ - gemv_t_masked( \ - const device itype* mat [[buffer(0)]], \ - const device itype* in_vec [[buffer(1)]], \ - device itype* out_vec [[buffer(3)]], \ - const constant int& in_vec_size [[buffer(4)]], \ - const constant int& out_vec_size [[buffer(5)]], \ - const constant int& marix_ld [[buffer(6)]], \ - const constant int& batch_ndim [[buffer(9)]], \ - const constant int* batch_shape [[buffer(10)]], \ - const constant size_t* vector_batch_stride [[buffer(11)]], \ - const constant size_t* matrix_batch_stride [[buffer(12)]], \ - const device outm_t* out_mask [[buffer(20)]], \ - const device opm_t* mat_mask [[buffer(21)]], \ - const device opm_t* vec_mask [[buffer(22)]], \ - const constant int* mask_strides [[buffer(23)]], \ - const constant size_t* mask_batch_strides [[buffer(24)]], \ - uint3 tid [[threadgroup_position_in_grid]], \ - uint3 lid [[thread_position_in_threadgroup]], \ - uint simd_gid [[simdgroup_index_in_threadgroup]], \ - uint simd_lid [[thread_index_in_simdgroup]]); + instantiate_kernel( \ + "gemv_t_outmask_" #outm_n "_opmask_" #opm_n "_" #name \ + "_bm" #bm "_bn" #bn "_sm" #sm "_sn" #sn "_tm" #tm \ + "_tn" #tn "_nc" #nc, \ + gemv_t_masked, itype, outm_t, opm_t, bm, bn, sm, sn, tm, tn, nc) #define instantiate_gemv_t_base(name, itype, bm, bn, sm, sn, tm, tn, nc) \ instantiate_gemv_t_helper(bool_, bool, bool_, bool, name, itype, bm, bn, sm, sn, tm, tn, nc) \ diff --git a/mlx/backend/metal/kernels/indexing.h b/mlx/backend/metal/kernels/indexing.h index 05bef96b6..2a4b4f929 100644 --- a/mlx/backend/metal/kernels/indexing.h +++ b/mlx/backend/metal/kernels/indexing.h @@ -8,7 +8,7 @@ template struct Indices { const array buffers; const constant int* shapes; - const constant size_t* strides; + const constant int64_t* strides; const constant bool* row_contiguous; const int ndim; }; diff --git a/mlx/backend/metal/kernels/quantized.h b/mlx/backend/metal/kernels/quantized.h index cbff318e6..33eec4910 100644 --- a/mlx/backend/metal/kernels/quantized.h +++ b/mlx/backend/metal/kernels/quantized.h @@ -1219,12 +1219,12 @@ METAL_FUNC void adjust_matrix_offsets( int output_stride, const constant int& x_batch_ndims, const constant int* x_shape, - const constant size_t* x_strides, + const constant int64_t* x_strides, const constant int& w_batch_ndims, const constant int* w_shape, - const constant size_t* w_strides, - const constant size_t* s_strides, - const constant size_t* b_strides, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + const constant int64_t* b_strides, uint3 tid [[threadgroup_position_in_grid]]) { // Set the input/output matrices uint32_t x_idx = tid.z; @@ -1260,16 +1260,16 @@ METAL_FUNC void adjust_matrix_offsets( int output_stride, const constant int& batch_ndims, const constant int* batch_shape, - const constant size_t* lhs_strides, - const constant size_t* rhs_strides, + const constant int64_t* 
lhs_strides, + const constant int64_t* rhs_strides, const constant int& x_batch_ndims, const constant int* x_shape, - const constant size_t* x_strides, + const constant int64_t* x_strides, const constant int& w_batch_ndims, const constant int* w_shape, - const constant size_t* w_strides, - const constant size_t* s_strides, - const constant size_t* b_strides, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + const constant int64_t* b_strides, uint3 tid [[threadgroup_position_in_grid]]) { // Set the input/output matrices uint32_t x_idx; @@ -1313,12 +1313,12 @@ template const constant int& out_vec_size [[buffer(6)]], const constant int& x_batch_ndims [[buffer(7)]], const constant int* x_shape [[buffer(8)]], - const constant size_t* x_strides [[buffer(9)]], + const constant int64_t* x_strides [[buffer(9)]], const constant int& w_batch_ndims [[buffer(10)]], const constant int* w_shape [[buffer(11)]], - const constant size_t* w_strides [[buffer(12)]], - const constant size_t* s_strides [[buffer(13)]], - const constant size_t* b_strides [[buffer(14)]], + const constant int64_t* w_strides [[buffer(12)]], + const constant int64_t* s_strides [[buffer(13)]], + const constant int64_t* b_strides [[buffer(14)]], uint3 tid [[threadgroup_position_in_grid]], uint quad_gid [[quadgroup_index_in_threadgroup]], uint quad_lid [[thread_index_in_quadgroup]]) { @@ -1364,12 +1364,12 @@ template const constant int& out_vec_size [[buffer(6)]], const constant int& x_batch_ndims [[buffer(7)]], const constant int* x_shape [[buffer(8)]], - const constant size_t* x_strides [[buffer(9)]], + const constant int64_t* x_strides [[buffer(9)]], const constant int& w_batch_ndims [[buffer(10)]], const constant int* w_shape [[buffer(11)]], - const constant size_t* w_strides [[buffer(12)]], - const constant size_t* s_strides [[buffer(13)]], - const constant size_t* b_strides [[buffer(14)]], + const constant int64_t* w_strides [[buffer(12)]], + const constant int64_t* s_strides [[buffer(13)]], + const constant int64_t* b_strides [[buffer(14)]], uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { @@ -1415,12 +1415,12 @@ template const constant int& out_vec_size [[buffer(6)]], const constant int& x_batch_ndims [[buffer(7)]], const constant int* x_shape [[buffer(8)]], - const constant size_t* x_strides [[buffer(9)]], + const constant int64_t* x_strides [[buffer(9)]], const constant int& w_batch_ndims [[buffer(10)]], const constant int* w_shape [[buffer(11)]], - const constant size_t* w_strides [[buffer(12)]], - const constant size_t* s_strides [[buffer(13)]], - const constant size_t* b_strides [[buffer(14)]], + const constant int64_t* w_strides [[buffer(12)]], + const constant int64_t* s_strides [[buffer(13)]], + const constant int64_t* b_strides [[buffer(14)]], uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { @@ -1466,12 +1466,12 @@ template const constant int& out_vec_size [[buffer(6)]], const constant int& x_batch_ndims [[buffer(7)]], const constant int* x_shape [[buffer(8)]], - const constant size_t* x_strides [[buffer(9)]], + const constant int64_t* x_strides [[buffer(9)]], const constant int& w_batch_ndims [[buffer(10)]], const constant int* w_shape [[buffer(11)]], - const constant size_t* w_strides [[buffer(12)]], - const constant size_t* s_strides [[buffer(13)]], - const constant size_t* b_strides [[buffer(14)]], + const 
constant int64_t* w_strides [[buffer(12)]], + const constant int64_t* s_strides [[buffer(13)]], + const constant int64_t* b_strides [[buffer(14)]], uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { @@ -1517,12 +1517,12 @@ template const constant int& out_vec_size [[buffer(6)]], const constant int& x_batch_ndims [[buffer(7)]], const constant int* x_shape [[buffer(8)]], - const constant size_t* x_strides [[buffer(9)]], + const constant int64_t* x_strides [[buffer(9)]], const constant int& w_batch_ndims [[buffer(10)]], const constant int* w_shape [[buffer(11)]], - const constant size_t* w_strides [[buffer(12)]], - const constant size_t* s_strides [[buffer(13)]], - const constant size_t* b_strides [[buffer(14)]], + const constant int64_t* w_strides [[buffer(12)]], + const constant int64_t* s_strides [[buffer(13)]], + const constant int64_t* b_strides [[buffer(14)]], const constant int& final_block_size [[buffer(15)]], uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], @@ -1581,12 +1581,12 @@ template < const constant int& M [[buffer(7)]], const constant int& x_batch_ndims [[buffer(8)]], const constant int* x_shape [[buffer(9)]], - const constant size_t* x_strides [[buffer(10)]], + const constant int64_t* x_strides [[buffer(10)]], const constant int& w_batch_ndims [[buffer(11)]], const constant int* w_shape [[buffer(12)]], - const constant size_t* w_strides [[buffer(13)]], - const constant size_t* s_strides [[buffer(14)]], - const constant size_t* b_strides [[buffer(15)]], + const constant int64_t* w_strides [[buffer(13)]], + const constant int64_t* s_strides [[buffer(14)]], + const constant int64_t* b_strides [[buffer(15)]], uint3 tid [[threadgroup_position_in_grid]], uint lid [[thread_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]], @@ -1639,12 +1639,12 @@ template < const constant int& M [[buffer(7)]], const constant int& x_batch_ndims [[buffer(8)]], const constant int* x_shape [[buffer(9)]], - const constant size_t* x_strides [[buffer(10)]], + const constant int64_t* x_strides [[buffer(10)]], const constant int& w_batch_ndims [[buffer(11)]], const constant int* w_shape [[buffer(12)]], - const constant size_t* w_strides [[buffer(13)]], - const constant size_t* s_strides [[buffer(14)]], - const constant size_t* b_strides [[buffer(15)]], + const constant int64_t* w_strides [[buffer(13)]], + const constant int64_t* s_strides [[buffer(14)]], + const constant int64_t* b_strides [[buffer(15)]], uint3 tid [[threadgroup_position_in_grid]], uint lid [[thread_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]], @@ -1691,18 +1691,18 @@ template const constant int& out_vec_size [[buffer(6)]], const constant int& x_batch_ndims [[buffer(7)]], const constant int* x_shape [[buffer(8)]], - const constant size_t* x_strides [[buffer(9)]], + const constant int64_t* x_strides [[buffer(9)]], const constant int& w_batch_ndims [[buffer(10)]], const constant int* w_shape [[buffer(11)]], - const constant size_t* w_strides [[buffer(12)]], - const constant size_t* s_strides [[buffer(13)]], - const constant size_t* b_strides [[buffer(14)]], + const constant int64_t* w_strides [[buffer(12)]], + const constant int64_t* s_strides [[buffer(13)]], + const constant int64_t* b_strides [[buffer(14)]], const constant int& batch_ndims [[buffer(15)]], const constant int* batch_shape [[buffer(16)]], const device uint32_t* lhs_indices [[buffer(17)]], const 
device uint32_t* rhs_indices [[buffer(18)]], - const constant size_t* lhs_strides [[buffer(19)]], - const constant size_t* rhs_strides [[buffer(20)]], + const constant int64_t* lhs_strides [[buffer(19)]], + const constant int64_t* rhs_strides [[buffer(20)]], uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { @@ -1752,18 +1752,18 @@ template const constant int& out_vec_size [[buffer(6)]], const constant int& x_batch_ndims [[buffer(7)]], const constant int* x_shape [[buffer(8)]], - const constant size_t* x_strides [[buffer(9)]], + const constant int64_t* x_strides [[buffer(9)]], const constant int& w_batch_ndims [[buffer(10)]], const constant int* w_shape [[buffer(11)]], - const constant size_t* w_strides [[buffer(12)]], - const constant size_t* s_strides [[buffer(13)]], - const constant size_t* b_strides [[buffer(14)]], + const constant int64_t* w_strides [[buffer(12)]], + const constant int64_t* s_strides [[buffer(13)]], + const constant int64_t* b_strides [[buffer(14)]], const constant int& batch_ndims [[buffer(15)]], const constant int* batch_shape [[buffer(16)]], const device uint32_t* lhs_indices [[buffer(17)]], const device uint32_t* rhs_indices [[buffer(18)]], - const constant size_t* lhs_strides [[buffer(19)]], - const constant size_t* rhs_strides [[buffer(20)]], + const constant int64_t* lhs_strides [[buffer(19)]], + const constant int64_t* rhs_strides [[buffer(20)]], uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { @@ -1813,18 +1813,18 @@ template const constant int& out_vec_size [[buffer(6)]], const constant int& x_batch_ndims [[buffer(7)]], const constant int* x_shape [[buffer(8)]], - const constant size_t* x_strides [[buffer(9)]], + const constant int64_t* x_strides [[buffer(9)]], const constant int& w_batch_ndims [[buffer(10)]], const constant int* w_shape [[buffer(11)]], - const constant size_t* w_strides [[buffer(12)]], - const constant size_t* s_strides [[buffer(13)]], - const constant size_t* b_strides [[buffer(14)]], + const constant int64_t* w_strides [[buffer(12)]], + const constant int64_t* s_strides [[buffer(13)]], + const constant int64_t* b_strides [[buffer(14)]], const constant int& batch_ndims [[buffer(15)]], const constant int* batch_shape [[buffer(16)]], const device uint32_t* lhs_indices [[buffer(17)]], const device uint32_t* rhs_indices [[buffer(18)]], - const constant size_t* lhs_strides [[buffer(19)]], - const constant size_t* rhs_strides [[buffer(20)]], + const constant int64_t* lhs_strides [[buffer(19)]], + const constant int64_t* rhs_strides [[buffer(20)]], uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { @@ -1882,18 +1882,18 @@ template < const constant int& M [[buffer(7)]], const constant int& x_batch_ndims [[buffer(8)]], const constant int* x_shape [[buffer(9)]], - const constant size_t* x_strides [[buffer(10)]], + const constant int64_t* x_strides [[buffer(10)]], const constant int& w_batch_ndims [[buffer(11)]], const constant int* w_shape [[buffer(12)]], - const constant size_t* w_strides [[buffer(13)]], - const constant size_t* s_strides [[buffer(14)]], - const constant size_t* b_strides [[buffer(15)]], + const constant int64_t* w_strides [[buffer(13)]], + const constant int64_t* s_strides [[buffer(14)]], + const constant int64_t* b_strides [[buffer(15)]], const constant int& 
batch_ndims [[buffer(16)]], const constant int* batch_shape [[buffer(17)]], const device uint32_t* lhs_indices [[buffer(18)]], const device uint32_t* rhs_indices [[buffer(19)]], - const constant size_t* lhs_strides [[buffer(20)]], - const constant size_t* rhs_strides [[buffer(21)]], + const constant int64_t* lhs_strides [[buffer(20)]], + const constant int64_t* rhs_strides [[buffer(21)]], uint3 tid [[threadgroup_position_in_grid]], uint lid [[thread_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]], @@ -1949,18 +1949,18 @@ template < const constant int& M [[buffer(7)]], const constant int& x_batch_ndims [[buffer(8)]], const constant int* x_shape [[buffer(9)]], - const constant size_t* x_strides [[buffer(10)]], + const constant int64_t* x_strides [[buffer(10)]], const constant int& w_batch_ndims [[buffer(11)]], const constant int* w_shape [[buffer(12)]], - const constant size_t* w_strides [[buffer(13)]], - const constant size_t* s_strides [[buffer(14)]], - const constant size_t* b_strides [[buffer(15)]], + const constant int64_t* w_strides [[buffer(13)]], + const constant int64_t* s_strides [[buffer(14)]], + const constant int64_t* b_strides [[buffer(15)]], const constant int& batch_ndims [[buffer(16)]], const constant int* batch_shape [[buffer(17)]], const device uint32_t* lhs_indices [[buffer(18)]], const device uint32_t* rhs_indices [[buffer(19)]], - const constant size_t* lhs_strides [[buffer(20)]], - const constant size_t* rhs_strides [[buffer(21)]], + const constant int64_t* lhs_strides [[buffer(20)]], + const constant int64_t* rhs_strides [[buffer(21)]], uint3 tid [[threadgroup_position_in_grid]], uint lid [[thread_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]], diff --git a/mlx/backend/metal/kernels/random.metal b/mlx/backend/metal/kernels/random.metal index f61663f8a..ccbd464d3 100644 --- a/mlx/backend/metal/kernels/random.metal +++ b/mlx/backend/metal/kernels/random.metal @@ -71,7 +71,7 @@ rbits threefry2x32_hash(const thread uint2& key, uint2 count) { constant const uint& bytes_per_key, constant const int& ndim, constant const int* key_shape, - constant const size_t* key_strides, + constant const int64_t* key_strides, uint2 grid_dim [[threads_per_grid]], uint2 index [[thread_position_in_grid]]) { auto kidx = 2 * index.x; diff --git a/mlx/backend/metal/kernels/reduce.metal b/mlx/backend/metal/kernels/reduce.metal index 4af18c970..dc2ce157c 100644 --- a/mlx/backend/metal/kernels/reduce.metal +++ b/mlx/backend/metal/kernels/reduce.metal @@ -59,10 +59,10 @@ instantiate_init_min_max(max, Max) itype, otype, op, uint, dim) \ instantiate_kernel("col_reduce_small_large_" #dim "_reduce_" #name, \ col_reduce_small, \ - itype, otype, op, size_t, dim) \ + itype, otype, op, int64_t, dim) \ instantiate_kernel("col_reduce_longcolumn_large_" #dim "_reduce_" #name, \ col_reduce_longcolumn, \ - itype, otype, op, size_t, dim) + itype, otype, op, int64_t, dim) #define instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, bm, bn) \ instantiate_kernel("col_reduce_looped_" #dim "_" #bm "_" #bn "_reduce_" #name, \ @@ -70,7 +70,7 @@ instantiate_init_min_max(max, Max) itype, otype, op, uint, dim, bm, bn) \ instantiate_kernel("col_reduce_looped_large_" #dim "_" #bm "_" #bn "_reduce_" #name, \ col_reduce_looped, \ - itype, otype, op, size_t, dim, bm, bn) + itype, otype, op, int64_t, dim, bm, bn) #define instantiate_col_reduce_2pass_tile(name, itype, otype, op, dim, bm, bn) \ instantiate_kernel("col_reduce_2pass_" #dim "_" #bm "_" #bn "_reduce_" 
#name, \ @@ -78,7 +78,7 @@ instantiate_init_min_max(max, Max) itype, otype, op, uint, dim, bm, bn) \ instantiate_kernel("col_reduce_2pass_large_" #dim "_" #bm "_" #bn "_reduce_" #name, \ col_reduce_2pass, \ - itype, otype, op, size_t, dim, bm, bn) + itype, otype, op, int64_t, dim, bm, bn) #define instantiate_col_reduce_looped(name, itype, otype, op, dim) \ instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, 32, 32) \ @@ -98,7 +98,7 @@ instantiate_init_min_max(max, Max) itype, otype, op, uint, dim) \ instantiate_kernel("row_reduce_small_large_" #dim "_reduce_" #name, \ row_reduce_small, \ - itype, otype, op, size_t, dim) + itype, otype, op, int64_t, dim) #define instantiate_row_reduce_looped(name, itype, otype, op, dim) \ instantiate_kernel("row_reduce_looped_" #dim "_reduce_" #name, \ @@ -106,7 +106,7 @@ instantiate_init_min_max(max, Max) itype, otype, op, uint, dim) \ instantiate_kernel("row_reduce_looped_large_" #dim "_reduce_" #name, \ row_reduce_looped, \ - itype, otype, op, size_t, dim) + itype, otype, op, int64_t, dim) #define instantiate_row_reduce_general(name, itype, otype, op) \ instantiate_row_reduce_small(name, itype, otype, op, 1) \ @@ -125,7 +125,7 @@ instantiate_init_min_max(max, Max) instantiate_col_reduce_general(name##tname, itype, otype, op) #define instantiate_and_or(name, op) \ - instantiate_reduce_functions(name, bool_, bool, bool, op) \ + instantiate_reduce_functions(name, bool_, bool, bool, op) \ instantiate_reduce_functions(name, int16, int16_t, bool, op) \ instantiate_reduce_functions(name, int32, int32_t, bool, op) \ instantiate_reduce_functions(name, int64, int64_t, bool, op) diff --git a/mlx/backend/metal/kernels/reduction/reduce_col.h b/mlx/backend/metal/kernels/reduction/reduce_col.h index 2fa5132d9..c109faf0b 100644 --- a/mlx/backend/metal/kernels/reduction/reduce_col.h +++ b/mlx/backend/metal/kernels/reduction/reduce_col.h @@ -5,12 +5,12 @@ template const device T* in [[buffer(0)]], device U* out [[buffer(1)]], const constant size_t& reduction_size [[buffer(2)]], - const constant size_t& reduction_stride [[buffer(3)]], + const constant int64_t& reduction_stride [[buffer(3)]], const constant int* shape [[buffer(4)]], - const constant size_t* strides [[buffer(5)]], + const constant int64_t* strides [[buffer(5)]], const constant int& ndim [[buffer(6)]], const constant int* reduce_shape [[buffer(7)]], - const constant size_t* reduce_strides [[buffer(8)]], + const constant int64_t* reduce_strides [[buffer(8)]], const constant int& reduce_ndim [[buffer(9)]], const constant size_t& non_col_reductions [[buffer(10)]], uint3 gid [[threadgroup_position_in_grid]], @@ -34,7 +34,7 @@ template bool safe = column + n_reads <= reduction_stride; IdxT out_idx = gid.y + gsize.y * IdxT(gid.z); - IdxT in_idx = elem_to_loc(out_idx, shape, strides, ndim); + IdxT in_idx = elem_to_loc(out_idx, shape, strides, ndim); in += in_idx + column; IdxT total_rows = IdxT(non_col_reductions) * IdxT(reduction_size); @@ -100,10 +100,10 @@ template const constant size_t& reduction_size [[buffer(2)]], const constant size_t& reduction_stride [[buffer(3)]], const constant int* shape [[buffer(4)]], - const constant size_t* strides [[buffer(5)]], + const constant int64_t* strides [[buffer(5)]], const constant int& ndim [[buffer(6)]], const constant int* reduce_shape [[buffer(7)]], - const constant size_t* reduce_strides [[buffer(8)]], + const constant int64_t* reduce_strides [[buffer(8)]], const constant int& reduce_ndim [[buffer(9)]], const constant size_t& non_col_reductions 
[[buffer(10)]], const constant size_t& out_size [[buffer(11)]], @@ -116,7 +116,7 @@ template const device T* row; IdxT out_idx = gid.x + gsize.x * IdxT(gid.y); - IdxT in_idx = elem_to_loc(out_idx, shape, strides, ndim); + IdxT in_idx = elem_to_loc(out_idx, shape, strides, ndim); in += in_idx + lid.x; U total = Op::init; @@ -164,12 +164,12 @@ template < const device T* in [[buffer(0)]], device U* out [[buffer(1)]], const constant size_t& reduction_size [[buffer(2)]], - const constant size_t& reduction_stride [[buffer(3)]], + const constant int64_t& reduction_stride [[buffer(3)]], const constant int* shape [[buffer(4)]], - const constant size_t* strides [[buffer(5)]], + const constant int64_t* strides [[buffer(5)]], const constant int& ndim [[buffer(6)]], const constant int* reduce_shape [[buffer(7)]], - const constant size_t* reduce_strides [[buffer(8)]], + const constant int64_t* reduce_strides [[buffer(8)]], const constant int& reduce_ndim [[buffer(9)]], const constant size_t& non_col_reductions [[buffer(10)]], uint3 gid [[threadgroup_position_in_grid]], @@ -197,7 +197,7 @@ template < bool safe = column + n_reads <= reduction_stride; IdxT out_idx = gid.y + gsize.y * IdxT(gid.z); - IdxT in_idx = elem_to_loc(out_idx, shape, strides, ndim); + IdxT in_idx = elem_to_loc(out_idx, shape, strides, ndim); in += in_idx + column; IdxT total = IdxT(non_col_reductions) * IdxT(reduction_size); @@ -303,12 +303,12 @@ template < const device T* in [[buffer(0)]], device U* out [[buffer(1)]], const constant size_t& reduction_size [[buffer(2)]], - const constant size_t& reduction_stride [[buffer(3)]], + const constant int64_t& reduction_stride [[buffer(3)]], const constant int* shape [[buffer(4)]], - const constant size_t* strides [[buffer(5)]], + const constant int64_t* strides [[buffer(5)]], const constant int& ndim [[buffer(6)]], const constant int* reduce_shape [[buffer(7)]], - const constant size_t* reduce_strides [[buffer(8)]], + const constant int64_t* reduce_strides [[buffer(8)]], const constant int& reduce_ndim [[buffer(9)]], const constant size_t& non_col_reductions [[buffer(10)]], const constant size_t& out_size [[buffer(11)]], @@ -342,7 +342,7 @@ template < IdxT full_idx = gid.y + gsize.y * IdxT(gid.z); IdxT block_idx = full_idx / IdxT(out_size); IdxT out_idx = full_idx % IdxT(out_size); - IdxT in_idx = elem_to_loc(out_idx, shape, strides, ndim); + IdxT in_idx = elem_to_loc(out_idx, shape, strides, ndim); in += in_idx + column; IdxT total = IdxT(non_col_reductions) * IdxT(reduction_size); diff --git a/mlx/backend/metal/kernels/reduction/reduce_row.h b/mlx/backend/metal/kernels/reduction/reduce_row.h index 746361255..c8973429f 100644 --- a/mlx/backend/metal/kernels/reduction/reduce_row.h +++ b/mlx/backend/metal/kernels/reduction/reduce_row.h @@ -98,11 +98,11 @@ template < METAL_FUNC void per_thread_row_reduce( thread U totals[N_WRITES], const device T* in, - const size_t row_idx, + const int64_t row_idx, int blocks, int extra, const constant int* shape, - const constant size_t* strides, + const constant int64_t* strides, const constant int& ndim, uint lsize_x, uint lid_x) { @@ -199,13 +199,13 @@ template < [[kernel]] void row_reduce_small( const device T* in [[buffer(0)]], device U* out [[buffer(1)]], - const constant size_t& row_size [[buffer(2)]], - const constant size_t& non_row_reductions [[buffer(3)]], + const constant int64_t& row_size [[buffer(2)]], + const constant int64_t& non_row_reductions [[buffer(3)]], const constant int* shape [[buffer(4)]], - const constant size_t* strides 
[[buffer(5)]], + const constant int64_t* strides [[buffer(5)]], const constant int& ndim [[buffer(6)]], const constant int* reduce_shape [[buffer(7)]], - const constant size_t* reduce_strides [[buffer(8)]], + const constant int64_t* reduce_strides [[buffer(8)]], const constant int& reduce_ndim [[buffer(9)]], uint simd_lane_id [[thread_index_in_simdgroup]], uint3 gid [[threadgroup_position_in_grid]], @@ -225,7 +225,7 @@ template < if ((non_row_reductions < 32 && row_size <= 8) || non_row_reductions <= 8) { // Simple loop over non_row_reductions and reduce the row in the thread. IdxT out_idx = tid.x + tsize.y * IdxT(tid.y); - in += elem_to_loc(out_idx, shape, strides, ndim); + in += elem_to_loc(out_idx, shape, strides, ndim); for (uint r = 0; r < non_row_reductions; r++) { row = in + loop.location(); @@ -238,7 +238,7 @@ template < // Collaboratively reduce over non_row_reductions in the simdgroup. Each // thread reduces every 32nd row and then a simple simd reduce. IdxT out_idx = gid.y + gsize.y * IdxT(gid.z); - in += elem_to_loc(out_idx, shape, strides, ndim); + in += elem_to_loc(out_idx, shape, strides, ndim); loop.next(simd_lane_id, reduce_shape, reduce_strides); @@ -260,14 +260,14 @@ template < typename T, typename U, typename Op, - typename IdxT = size_t, + typename IdxT = int64_t, int N_READS = REDUCE_N_READS, int N_WRITES = REDUCE_N_WRITES> [[kernel]] void row_reduce_simple( const device T* in [[buffer(0)]], device U* out [[buffer(1)]], const constant size_t& reduction_size [[buffer(2)]], - const constant size_t& out_size [[buffer(3)]], + const constant int64_t& out_size [[buffer(3)]], uint3 gid [[threadgroup_position_in_grid]], uint3 gsize [[threadgroups_per_grid]], uint3 lid [[thread_position_in_threadgroup]], @@ -314,13 +314,13 @@ template < [[kernel]] void row_reduce_looped( const device T* in [[buffer(0)]], device U* out [[buffer(1)]], - const constant size_t& row_size [[buffer(2)]], - const constant size_t& non_row_reductions [[buffer(3)]], + const constant int64_t& row_size [[buffer(2)]], + const constant int64_t& non_row_reductions [[buffer(3)]], const constant int* shape [[buffer(4)]], - const constant size_t* strides [[buffer(5)]], + const constant int64_t* strides [[buffer(5)]], const constant int& ndim [[buffer(6)]], const constant int* reduce_shape [[buffer(7)]], - const constant size_t* reduce_strides [[buffer(8)]], + const constant int64_t* reduce_strides [[buffer(8)]], const constant int& reduce_ndim [[buffer(9)]], uint3 gid [[threadgroup_position_in_grid]], uint3 gsize [[threadgroups_per_grid]], @@ -337,8 +337,7 @@ template < // lid.x * N_READS breaks the per_thread_row_reduce interface a bit. Maybe it // needs a small refactor. 
- in += elem_to_loc(out_idx, shape, strides, ndim) + - lid.x * N_READS; + in += elem_to_loc(out_idx, shape, strides, ndim) + lid.x * N_READS; LoopedElemToLoc 2)> loop(reduce_ndim); const device T* row; diff --git a/mlx/backend/metal/kernels/scatter.h b/mlx/backend/metal/kernels/scatter.h index 63b09df3d..d96eca3db 100644 --- a/mlx/backend/metal/kernels/scatter.h +++ b/mlx/backend/metal/kernels/scatter.h @@ -16,11 +16,11 @@ METAL_FUNC void scatter_impl( const device T* updates, device mlx_atomic* out, const constant int* upd_shape, - const constant size_t* upd_strides, + const constant int64_t* upd_strides, const constant size_t& upd_ndim, const constant size_t& upd_size, const constant int* out_shape, - const constant size_t* out_strides, + const constant int64_t* out_strides, const constant size_t& out_ndim, const constant int* axes, const constant size_t& idx_size, @@ -31,7 +31,7 @@ METAL_FUNC void scatter_impl( auto ind_idx = gid.y * NWORK; LocT out_offset = 0; if (upd_size > 1) { - out_offset = elem_to_loc( + out_offset = elem_to_loc( gid.x, upd_shape + indices.ndim, out_strides, out_ndim); } @@ -40,7 +40,7 @@ METAL_FUNC void scatter_impl( for (int i = 0; i < NIDX; ++i) { auto idx_loc = indices.row_contiguous[i] ? ind_idx - : elem_to_loc( + : elem_to_loc( ind_idx, &indices.shapes[indices.ndim * i], &indices.strides[indices.ndim * i], @@ -52,8 +52,7 @@ METAL_FUNC void scatter_impl( } auto upd_idx = ind_idx * static_cast(upd_size) + gid.x; if constexpr (!UPD_ROW_CONTIG) { - upd_idx = - elem_to_loc(upd_idx, upd_shape, upd_strides, upd_ndim); + upd_idx = elem_to_loc(upd_idx, upd_shape, upd_strides, upd_ndim); } op.atomic_update(out, updates[upd_idx], out_idx); } diff --git a/mlx/backend/metal/kernels/sort.h b/mlx/backend/metal/kernels/sort.h index bfa4d98ea..e9a00c777 100644 --- a/mlx/backend/metal/kernels/sort.h +++ b/mlx/backend/metal/kernels/sort.h @@ -343,8 +343,8 @@ template < const constant int& out_stride_sorted_axis [[buffer(4)]], const constant int& nc_dim [[buffer(5)]], const constant int* nc_shape [[buffer(6)]], - const constant size_t* in_nc_strides [[buffer(7)]], - const constant size_t* out_nc_strides [[buffer(8)]], + const constant int64_t* in_nc_strides [[buffer(7)]], + const constant int64_t* out_nc_strides [[buffer(8)]], uint3 tid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) { using sort_kernel = @@ -486,7 +486,7 @@ template < const constant int& stride_sorted_axis [[buffer(4)]], const constant int& nc_dim [[buffer(5)]], const constant int* nc_shape [[buffer(6)]], - const constant size_t* nc_strides [[buffer(7)]], + const constant int64_t* nc_strides [[buffer(7)]], uint3 tid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) { using sort_kernel = KernelMultiBlockMergeSort< diff --git a/mlx/backend/metal/kernels/steel/attn/params.h b/mlx/backend/metal/kernels/steel/attn/params.h index a9d7c7b4a..4f9680412 100644 --- a/mlx/backend/metal/kernels/steel/attn/params.h +++ b/mlx/backend/metal/kernels/steel/attn/params.h @@ -26,10 +26,10 @@ struct AttnParams { int NQ_aligned; ///< Number of full query blocks int NK_aligned; ///< Number of full key/value blocks - size_t Q_strides[3]; ///< Query strides (B, H, L, D = 1) - size_t K_strides[3]; ///< Key strides (B, H, L, D = 1) - size_t V_strides[3]; ///< Value strides (B, H, L, D = 1) - size_t O_strides[3]; ///< Output strides (B, H, L, D = 1) + int64_t Q_strides[3]; ///< Query strides (B, H, L, D = 1) + int64_t K_strides[3]; ///< Key strides (B, H, L, D = 1) + int64_t 
V_strides[3]; ///< Value strides (B, H, L, D = 1) + int64_t O_strides[3]; ///< Output strides (B, H, L, D = 1) }; } // namespace steel diff --git a/mlx/backend/metal/kernels/steel/conv/params.h b/mlx/backend/metal/kernels/steel/conv/params.h index f75851dc8..61b8474f8 100644 --- a/mlx/backend/metal/kernels/steel/conv/params.h +++ b/mlx/backend/metal/kernels/steel/conv/params.h @@ -14,9 +14,9 @@ struct MLXConvParams { const int pad[NDIM]; // Input padding const int kdil[NDIM]; // Kernel dilation const int idil[NDIM]; // Input dilation - const size_t in_strides[NDIM + 2]; // In strides - const size_t wt_strides[NDIM + 2]; // Wt strides - const size_t out_strides[NDIM + 2]; // Out strides + const int64_t in_strides[NDIM + 2]; // In strides + const int64_t wt_strides[NDIM + 2]; // Wt strides + const int64_t out_strides[NDIM + 2]; // Out strides const int groups; // Input channel groups const bool flip; }; @@ -59,4 +59,4 @@ struct Conv2DGeneralBaseInfo { }; } // namespace steel -} // namespace mlx \ No newline at end of file +} // namespace mlx diff --git a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h index 5e1d2f231..bcc585bbe 100644 --- a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +++ b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h @@ -38,12 +38,12 @@ template < const constant GEMMParams* params [[buffer(4)]], const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]], const constant int* batch_shape [[buffer(6)]], - const constant size_t* batch_strides [[buffer(7)]], + const constant int64_t* batch_strides [[buffer(7)]], const constant uint32_t* lhs_indices [[buffer(10), function_constant(do_gather)]], const constant uint32_t* rhs_indices [[buffer(11), function_constant(do_gather)]], const constant uint32_t* C_indices [[buffer(12), function_constant(gather_bias)]], const constant int* operand_shape [[buffer(13), function_constant(do_gather)]], - const constant size_t* operand_strides [[buffer(14), function_constant(do_gather)]], + const constant int64_t* operand_strides [[buffer(14), function_constant(do_gather)]], const constant packed_int3& operand_batch_ndim [[buffer(15), function_constant(do_gather)]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]], @@ -88,9 +88,8 @@ template < uint32_t indx_A, indx_B, indx_C; if (has_batch) { - const constant size_t* indx_A_bstrides = batch_strides; - const constant size_t* indx_B_bstrides = - batch_strides + params->batch_ndim; + const constant auto* indx_A_bstrides = batch_strides; + const constant auto* indx_B_bstrides = batch_strides + params->batch_ndim; ulong2 indx_offsets = elem_to_loc_broadcast( tid.z, @@ -102,7 +101,7 @@ template < indx_B = rhs_indices[indx_offsets.y]; if (use_out_source) { - const constant size_t* indx_C_bstrides = + const constant auto* indx_C_bstrides = indx_B_bstrides + params->batch_ndim; auto indx_offset_C = elem_to_loc( tid.z, batch_shape, indx_C_bstrides, params->batch_ndim); @@ -120,18 +119,18 @@ template < // Translate indices to offsets int batch_ndim_A = operand_batch_ndim.x; const constant int* batch_shape_A = operand_shape; - const constant size_t* batch_strides_A = operand_strides; + const constant auto* batch_strides_A = operand_strides; A += elem_to_loc(indx_A, batch_shape_A, batch_strides_A, batch_ndim_A); int batch_ndim_B = operand_batch_ndim.y; const constant int* batch_shape_B = batch_shape_A + 
batch_ndim_A; - const constant size_t* batch_strides_B = batch_strides_A + batch_ndim_A; + const constant auto* batch_strides_B = batch_strides_A + batch_ndim_A; B += elem_to_loc(indx_B, batch_shape_B, batch_strides_B, batch_ndim_B); if (use_out_source) { int batch_ndim_C = operand_batch_ndim.z; const constant int* batch_shape_C = batch_shape_B + batch_ndim_B; - const constant size_t* batch_strides_C = batch_strides_B + batch_ndim_B; + const constant auto* batch_strides_C = batch_strides_B + batch_ndim_B; C += elem_to_loc(indx_C, batch_shape_C, batch_strides_C, batch_ndim_C); } @@ -140,8 +139,8 @@ template < // Handle regular batch else { if (has_batch) { - const constant size_t* A_bstrides = batch_strides; - const constant size_t* B_bstrides = batch_strides + params->batch_ndim; + const constant auto* A_bstrides = batch_strides; + const constant auto* B_bstrides = batch_strides + params->batch_ndim; ulong2 batch_offsets = elem_to_loc_broadcast( tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim); @@ -150,7 +149,7 @@ template < B += batch_offsets.y; if (use_out_source) { - const constant size_t* C_bstrides = B_bstrides + params->batch_ndim; + const constant auto* C_bstrides = B_bstrides + params->batch_ndim; C += elem_to_loc(tid.z, batch_shape, C_bstrides, params->batch_ndim); } } else { diff --git a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal index 4333be26c..e03d4beb2 100644 --- a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal +++ b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal @@ -7,26 +7,10 @@ #include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h" #define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ - template [[host_name("steel_gemm_fused_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn)]] \ - [[kernel]] void gemm( \ - const device itype *A [[buffer(0)]], \ - const device itype *B [[buffer(1)]], \ - const device itype *C [[buffer(2), function_constant(use_out_source)]], \ - device itype *D [[buffer(3)]], \ - const constant GEMMParams* params [[buffer(4)]], \ - const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]], \ - const constant int* batch_shape [[buffer(6)]], \ - const constant size_t* batch_strides [[buffer(7)]], \ - const constant uint32_t* lhs_indices [[buffer(10), function_constant(do_gather)]], \ - const constant uint32_t* rhs_indices [[buffer(11), function_constant(do_gather)]], \ - const constant uint32_t* C_indices [[buffer(12), function_constant(gather_bias)]], \ - const constant int* operand_shape [[buffer(13), function_constant(do_gather)]], \ - const constant size_t* operand_strides [[buffer(14), function_constant(do_gather)]], \ - const constant packed_int3& operand_batch_ndim [[buffer(15), function_constant(do_gather)]], \ - uint simd_lane_id [[thread_index_in_simdgroup]], \ - uint simd_group_id [[simdgroup_index_in_threadgroup]], \ - uint3 tid [[threadgroup_position_in_grid]], \ - uint3 lid [[thread_position_in_threadgroup]]); + instantiate_kernel( \ + "steel_gemm_fused_" #tname "_" #iname "_" #oname \ + "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn, \ + gemm, itype, bm, bn, bk, wm, wn, trans_a, trans_b, float) #define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \ instantiate_gemm(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, 
wn) \ diff --git a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h index 702e13152..c8ffe2b8e 100644 --- a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +++ b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h @@ -56,7 +56,7 @@ block_masked_gemm( device T* D [[buffer(3)]], const constant GEMMParams* params [[buffer(4)]], const constant int* batch_shape [[buffer(6)]], - const constant size_t* batch_strides [[buffer(7)]], + const constant int64_t* batch_strides [[buffer(7)]], const device out_mask_t* out_mask [[buffer(10)]], const device op_mask_t* lhs_mask [[buffer(11)]], const device op_mask_t* rhs_mask [[buffer(12)]], @@ -104,7 +104,7 @@ block_masked_gemm( return; } - const constant size_t* mask_batch_strides = + const constant auto* mask_batch_strides = batch_strides + 2 * params->batch_ndim; if (params->batch_ndim > 1) { @@ -116,8 +116,8 @@ block_masked_gemm( } if (has_operand_mask) { - const constant size_t* mask_strides_lhs = mask_batch_strides; - const constant size_t* mask_strides_rhs = + const constant auto* mask_strides_lhs = mask_batch_strides; + const constant auto* mask_strides_rhs = mask_strides_lhs + params->batch_ndim; ulong2 batch_offsets = elem_to_loc_broadcast( @@ -144,8 +144,8 @@ block_masked_gemm( // Adjust for batch if (params->batch_ndim > 1) { - const constant size_t* A_bstrides = batch_strides; - const constant size_t* B_bstrides = batch_strides + params->batch_ndim; + const constant auto* A_bstrides = batch_strides; + const constant auto* B_bstrides = batch_strides + params->batch_ndim; ulong2 batch_offsets = elem_to_loc_broadcast( tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim); @@ -442,7 +442,7 @@ block_masked_gemm( device T* D [[buffer(3)]], const constant GEMMParams* params [[buffer(4)]], const constant int* batch_shape [[buffer(6)]], - const constant size_t* batch_strides [[buffer(7)]], + const constant int64_t* batch_strides [[buffer(7)]], const device bool* out_mask [[buffer(10)]], const device bool* lhs_mask [[buffer(11)]], const device bool* rhs_mask [[buffer(12)]], @@ -476,15 +476,15 @@ block_masked_gemm( } if (params->batch_ndim > 1) { - const constant size_t* mask_batch_strides = + const constant auto* mask_batch_strides = batch_strides + 2 * params->batch_ndim; out_mask += elem_to_loc(tid.z, batch_shape, mask_batch_strides, params->batch_ndim); if (has_operand_mask) { - const constant size_t* mask_strides_lhs = + const constant auto* mask_strides_lhs = mask_batch_strides + params->batch_ndim; - const constant size_t* mask_strides_rhs = + const constant auto* mask_strides_rhs = mask_strides_lhs + params->batch_ndim; ulong2 batch_offsets = elem_to_loc_broadcast( @@ -507,8 +507,8 @@ block_masked_gemm( // Adjust for batch if (params->batch_ndim > 1) { - const constant size_t* A_bstrides = batch_strides; - const constant size_t* B_bstrides = batch_strides + params->batch_ndim; + const constant auto* A_bstrides = batch_strides; + const constant auto* B_bstrides = batch_strides + params->batch_ndim; ulong2 batch_offsets = elem_to_loc_broadcast( tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim); diff --git a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal index c127893ff..af34a6870 100644 --- a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal +++ 
b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal @@ -5,58 +5,45 @@ #include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h" -#define instantiate_gemm( \ - outmaskname, \ - outmasktype, \ - opmaskname, \ - opmasktype, \ - tname, \ - trans_a, \ - trans_b, \ - iname, \ - itype, \ - oname, \ - otype, \ - bm, \ - bn, \ - bk, \ - wm, \ - wn, \ - aname, \ - mn_aligned, \ - kname, \ - k_aligned) \ - template [[host_name("steel_gemm_block_outmask_" #outmaskname \ - "_opmask_" #opmaskname "_" #tname "_" #iname "_" #oname \ - "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn \ - "_MN_" #aname "_K_" #kname)]] [[kernel]] void \ - block_masked_gemm< \ - itype, \ - outmasktype, \ - opmasktype, \ - bm, \ - bn, \ - bk, \ - wm, \ - wn, \ - trans_a, \ - trans_b, \ - mn_aligned, \ - k_aligned>( \ - const device itype* A [[buffer(0)]], \ - const device itype* B [[buffer(1)]], \ - device itype* D [[buffer(3)]], \ - const constant GEMMParams* params [[buffer(4)]], \ - const constant int* batch_shape [[buffer(6)]], \ - const constant size_t* batch_strides [[buffer(7)]], \ - const device outmasktype* out_mask [[buffer(10)]], \ - const device opmasktype* lhs_mask [[buffer(11)]], \ - const device opmasktype* rhs_mask [[buffer(12)]], \ - const constant int* mask_strides [[buffer(13)]], \ - uint simd_lane_id [[thread_index_in_simdgroup]], \ - uint simd_group_id [[simdgroup_index_in_threadgroup]], \ - uint3 tid [[threadgroup_position_in_grid]], \ - uint3 lid [[thread_position_in_threadgroup]]); +#define instantiate_gemm( \ + outmaskname, \ + outmasktype, \ + opmaskname, \ + opmasktype, \ + tname, \ + trans_a, \ + trans_b, \ + iname, \ + itype, \ + oname, \ + otype, \ + bm, \ + bn, \ + bk, \ + wm, \ + wn, \ + aname, \ + mn_aligned, \ + kname, \ + k_aligned) \ + instantiate_kernel( \ + "steel_gemm_block_outmask_" #outmaskname \ + "_opmask_" #opmaskname "_" #tname "_" #iname "_" #oname \ + "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn \ + "_MN_" #aname "_K_" #kname, \ + block_masked_gemm, \ + itype, \ + outmasktype, \ + opmasktype, \ + bm, \ + bn, \ + bk, \ + wm, \ + wn, \ + trans_a, \ + trans_b, \ + mn_aligned, \ + k_aligned) #define instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \ instantiate_gemm(bool_, bool, bool_, bool, tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \ diff --git a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal index 739e3f30e..a046515b0 100644 --- a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal +++ b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal @@ -5,46 +5,39 @@ #include "mlx/backend/metal/kernels/steel/gemm/gemm.h" #include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h" -#define instantiate_gemm( \ - tname, \ - trans_a, \ - trans_b, \ - iname, \ - itype, \ - oname, \ - otype, \ - bm, \ - bn, \ - bk, \ - wm, \ - wn, \ - aname, \ - mn_aligned, \ - kname, \ - k_aligned) \ - template [[host_name("steel_gemm_splitk_" #tname "_" #iname "_" #oname \ - "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn \ - "_MN_" #aname "_K_" #kname)]] [[kernel]] void \ - gemm_splitk< \ - itype, \ - otype, \ - bm, \ - bn, \ - bk, \ - wm, \ - wn, \ - trans_a, \ - trans_b, \ - mn_aligned, \ - k_aligned>( \ - const device itype* A 
[[buffer(0)]], \ - const device itype* B [[buffer(1)]], \ - device otype* C [[buffer(2)]], \ - const constant GEMMSpiltKParams* params [[buffer(3)]], \ - uint simd_lane_id [[thread_index_in_simdgroup]], \ - uint simd_group_id [[simdgroup_index_in_threadgroup]], \ - uint3 tid [[threadgroup_position_in_grid]], \ - uint3 lid [[thread_position_in_threadgroup]]); +#define instantiate_gemm( \ + tname, \ + trans_a, \ + trans_b, \ + iname, \ + itype, \ + oname, \ + otype, \ + bm, \ + bn, \ + bk, \ + wm, \ + wn, \ + aname, \ + mn_aligned, \ + kname, \ + k_aligned) \ + instantiate_kernel( \ + "steel_gemm_splitk_" #tname "_" #iname "_" #oname \ + "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn \ + "_MN_" #aname "_K_" #kname, \ + gemm_splitk, \ + itype, \ + otype, \ + bm, \ + bn, \ + bk, \ + wm, \ + wn, \ + trans_a, \ + trans_b, \ + mn_aligned, \ + k_aligned) #define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \ @@ -68,30 +61,13 @@ instantiate_gemm_shapes_helper(float16, half, float32, float); instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, float32, float); instantiate_gemm_shapes_helper(float32, float, float32, float); -#define instantiate_accum(oname, otype, aname, atype) \ - template [[host_name("steel_gemm_splitk_accum_" #oname \ - "_" #aname)]] [[kernel]] void \ - gemm_splitk_accum( \ - const device atype* C_split [[buffer(0)]], \ - device otype* D [[buffer(1)]], \ - const constant int& k_partitions [[buffer(2)]], \ - const constant int& partition_stride [[buffer(3)]], \ - const constant int& ldd [[buffer(4)]], \ - uint2 gid [[thread_position_in_grid]]); \ - template [[host_name("steel_gemm_splitk_accum_" #oname "_" #aname \ - "_axbpy")]] [[kernel]] void \ - gemm_splitk_accum_axpby( \ - const device atype* C_split [[buffer(0)]], \ - device otype* D [[buffer(1)]], \ - const constant int& k_partitions [[buffer(2)]], \ - const constant int& partition_stride [[buffer(3)]], \ - const constant int& ldd [[buffer(4)]], \ - const device otype* C [[buffer(5)]], \ - const constant int& ldc [[buffer(6)]], \ - const constant int& fdc [[buffer(7)]], \ - const constant float& alpha [[buffer(8)]], \ - const constant float& beta [[buffer(9)]], \ - uint2 gid [[thread_position_in_grid]]); +#define instantiate_accum(oname, otype, aname, atype) \ + instantiate_kernel( \ + "steel_gemm_splitk_accum_" #oname "_" #aname, \ + gemm_splitk_accum, atype, otype) \ + instantiate_kernel( \ + "steel_gemm_splitk_accum_" #oname "_" #aname "_axbpy", \ + gemm_splitk_accum_axpby, atype, otype) \ instantiate_accum(bfloat16, bfloat16_t, float32, float); instantiate_accum(float16, half, float32, float); diff --git a/mlx/backend/metal/kernels/steel/gemm/params.h b/mlx/backend/metal/kernels/steel/gemm/params.h index e8bcb2217..3cb7bdc30 100644 --- a/mlx/backend/metal/kernels/steel/gemm/params.h +++ b/mlx/backend/metal/kernels/steel/gemm/params.h @@ -21,9 +21,9 @@ struct GEMMParams { const int tiles_n; const int tiles_m; - const size_t batch_stride_a; - const size_t batch_stride_b; - const size_t batch_stride_d; + const int64_t batch_stride_a; + const int64_t batch_stride_b; + const int64_t batch_stride_d; const int swizzle_log; const int gemm_k_iterations_aligned; @@ -54,7 +54,7 @@ struct GEMMAddMMParams { const int ldc; const int fdc; - const size_t batch_stride_c; + const int64_t batch_stride_c; const float alpha; const float beta; diff --git 
a/mlx/backend/metal/kernels/steel/utils.h b/mlx/backend/metal/kernels/steel/utils.h index 322b22503..55720a28f 100644 --- a/mlx/backend/metal/kernels/steel/utils.h +++ b/mlx/backend/metal/kernels/steel/utils.h @@ -7,8 +7,8 @@ METAL_FUNC ulong2 elem_to_loc_broadcast( uint elem, constant const int* shape, - constant const size_t* a_strides, - constant const size_t* b_strides, + constant const int64_t* a_strides, + constant const int64_t* b_strides, int ndim) { ulong loc_a{0}; ulong loc_b{0}; @@ -24,9 +24,9 @@ METAL_FUNC ulong2 elem_to_loc_broadcast( METAL_FUNC ulong3 elem_to_loc_broadcast( uint elem, constant const int* shape, - constant const size_t* a_strides, - constant const size_t* b_strides, - constant const size_t* c_strides, + constant const int64_t* a_strides, + constant const int64_t* b_strides, + constant const int64_t* c_strides, int ndim) { ulong loc_a{0}; ulong loc_b{0}; diff --git a/mlx/backend/metal/kernels/ternary.h b/mlx/backend/metal/kernels/ternary.h index 3cf776ad5..4b3adcc80 100644 --- a/mlx/backend/metal/kernels/ternary.h +++ b/mlx/backend/metal/kernels/ternary.h @@ -18,72 +18,72 @@ template device T* d, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - size_t offset = index.x + grid_dim.x * size_t(index.y); + auto offset = index.x + grid_dim.x * int64_t(index.y); d[offset] = Op()(a[offset], b[offset], c[offset]); } -template +template [[kernel]] void ternary_g_nd1( device const bool* a, device const T* b, device const T* c, device T* d, - constant const size_t& a_strides, - constant const size_t& b_strides, - constant const size_t& c_strides, + constant const int64_t& a_strides, + constant const int64_t& b_strides, + constant const int64_t& c_strides, uint index [[thread_position_in_grid]]) { - auto a_idx = elem_to_loc_1(index, a_strides); - auto b_idx = elem_to_loc_1(index, b_strides); - auto c_idx = elem_to_loc_1(index, c_strides); + auto a_idx = elem_to_loc_1(index, a_strides); + auto b_idx = elem_to_loc_1(index, b_strides); + auto c_idx = elem_to_loc_1(index, c_strides); d[index] = Op()(a[a_idx], b[b_idx], c[c_idx]); } -template +template [[kernel]] void ternary_g_nd2( device const bool* a, device const T* b, device const T* c, device T* d, - constant const size_t a_strides[2], - constant const size_t b_strides[2], - constant const size_t c_strides[2], + constant const int64_t a_strides[2], + constant const int64_t b_strides[2], + constant const int64_t c_strides[2], uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto a_idx = elem_to_loc_2(index, a_strides); - auto b_idx = elem_to_loc_2(index, b_strides); - auto c_idx = elem_to_loc_2(index, c_strides); + auto a_idx = elem_to_loc_2(index, a_strides); + auto b_idx = elem_to_loc_2(index, b_strides); + auto c_idx = elem_to_loc_2(index, c_strides); IdxT out_idx = index.x + IdxT(grid_dim.x) * index.y; d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]); } -template +template [[kernel]] void ternary_g_nd3( device const bool* a, device const T* b, device const T* c, device T* d, - constant const size_t a_strides[3], - constant const size_t b_strides[3], - constant const size_t c_strides[3], + constant const int64_t a_strides[3], + constant const int64_t b_strides[3], + constant const int64_t c_strides[3], uint3 index [[thread_position_in_grid]], uint3 grid_dim [[threads_per_grid]]) { - auto a_idx = elem_to_loc_3(index, a_strides); - auto b_idx = elem_to_loc_3(index, b_strides); - auto c_idx = elem_to_loc_3(index, c_strides); + auto a_idx = elem_to_loc_3(index, 
a_strides); + auto b_idx = elem_to_loc_3(index, b_strides); + auto c_idx = elem_to_loc_3(index, c_strides); IdxT out_idx = index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z); d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]); } -template +template [[kernel]] void ternary_g( device const bool* a, device const T* b, device const T* c, device T* d, constant const int* shape, - constant const size_t* a_strides, - constant const size_t* b_strides, - constant const size_t* c_strides, + constant const int64_t* a_strides, + constant const int64_t* b_strides, + constant const int64_t* c_strides, constant const int& ndim, uint3 index [[thread_position_in_grid]], uint3 grid_dim [[threads_per_grid]]) { diff --git a/mlx/backend/metal/kernels/unary.h b/mlx/backend/metal/kernels/unary.h index acfe176ef..69828599f 100644 --- a/mlx/backend/metal/kernels/unary.h +++ b/mlx/backend/metal/kernels/unary.h @@ -14,7 +14,7 @@ template device U* out, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - size_t offset = index.x + grid_dim.x * size_t(index.y); + auto offset = index.x + grid_dim.x * int64_t(index.y); out[offset] = Op()(in[offset]); } @@ -23,16 +23,16 @@ template < typename U, typename Op, int N = 1, - typename IdxT = size_t> + typename IdxT = int64_t> [[kernel]] void unary_g( device const T* in, device U* out, constant const int* in_shape, - constant const size_t* in_strides, + constant const int64_t* in_strides, device const int& ndim, uint3 index [[thread_position_in_grid]], uint3 grid_dim [[threads_per_grid]]) { - auto idx = elem_to_loc( + auto idx = elem_to_loc( {N * index.x, index.y, index.z}, in_shape, in_strides, ndim); auto xshape = in_shape[ndim - 1]; IdxT xstride = in_strides[ndim - 1]; diff --git a/mlx/backend/metal/kernels/utils.h b/mlx/backend/metal/kernels/utils.h index c6add37f9..d96656075 100644 --- a/mlx/backend/metal/kernels/utils.h +++ b/mlx/backend/metal/kernels/utils.h @@ -89,11 +89,11 @@ struct Limits { /////////////////////////////////////////////////////////////////////////////// // Single Array with generic dims -template +template METAL_FUNC IdxT elem_to_loc( uint elem, constant const int* shape, - constant const StrideT* strides, + constant const int64_t* strides, int ndim) { IdxT loc = 0; for (int i = ndim - 1; i >= 0 && elem > 0; --i) { @@ -103,11 +103,11 @@ METAL_FUNC IdxT elem_to_loc( return loc; } -template +template METAL_FUNC IdxT elem_to_loc( - StrideT elem, + int64_t elem, constant const int* shape, - constant const StrideT* strides, + constant const int64_t* strides, int ndim) { IdxT loc = 0; for (int i = ndim - 1; i >= 0 && elem > 0; --i) { @@ -118,11 +118,11 @@ METAL_FUNC IdxT elem_to_loc( } // Non templated version to handle arbitrary dims -template +template METAL_FUNC IdxT elem_to_loc( uint3 elem, constant const int* shape, - constant const StrideT* strides, + constant const int64_t* strides, int ndim) { IdxT loc = elem.x * IdxT(strides[ndim - 1]) + elem.y * IdxT(strides[ndim - 2]); @@ -136,18 +136,18 @@ METAL_FUNC IdxT elem_to_loc( /////////////////////////////////////////////////////////////////////////////// // Single Array with fixed N dims -template -METAL_FUNC IdxT elem_to_loc_1(uint elem, constant const StrideT& stride) { +template +METAL_FUNC IdxT elem_to_loc_1(uint elem, constant const int64_t& stride) { return elem * IdxT(stride); } -template -METAL_FUNC IdxT elem_to_loc_2(uint2 elem, constant const StrideT strides[2]) { +template +METAL_FUNC IdxT elem_to_loc_2(uint2 elem, constant const int64_t strides[2]) { 
return elem.x * IdxT(strides[1]) + elem.y * IdxT(strides[0]); } -template -METAL_FUNC IdxT elem_to_loc_3(uint3 elem, constant const StrideT strides[3]) { +template +METAL_FUNC IdxT elem_to_loc_3(uint3 elem, constant const int64_t strides[3]) { return elem.x * IdxT(strides[2]) + elem.y * IdxT(strides[1]) + elem.z * IdxT(strides[0]); } @@ -155,12 +155,12 @@ METAL_FUNC IdxT elem_to_loc_3(uint3 elem, constant const StrideT strides[3]) { /////////////////////////////////////////////////////////////////////////////// // Multiple Arrays with generic dims -template +template METAL_FUNC vec elem_to_loc_2_nd( uint3 elem, constant const int* shape, - constant const StrideT* a_strides, - constant const StrideT* b_strides, + constant const int64_t* a_strides, + constant const int64_t* b_strides, int ndim) { vec loc = { IdxT( @@ -178,13 +178,13 @@ METAL_FUNC vec elem_to_loc_2_nd( return loc; } -template +template METAL_FUNC vec elem_to_loc_3_nd( uint3 elem, constant const int* shape, - constant const size_t* a_strides, - constant const size_t* b_strides, - constant const size_t* c_strides, + constant const int64_t* a_strides, + constant const int64_t* b_strides, + constant const int64_t* c_strides, int ndim) { vec loc = { elem.x * IdxT(a_strides[ndim - 1]) + elem.y * IdxT(a_strides[ndim - 2]), @@ -213,7 +213,7 @@ struct LoopedElemToLoc { LoopedElemToLoc(int dim) : dim(dim), inner_looper(dim - 1) {} - void next(const constant int* shape, const constant size_t* strides) { + void next(const constant int* shape, const constant int64_t* strides) { if (dim == 0) { return; } @@ -226,7 +226,7 @@ struct LoopedElemToLoc { } } - void next(int n, const constant int* shape, const constant size_t* strides) { + void next(int n, const constant int* shape, const constant int64_t* strides) { if (dim == 0) { return; } @@ -262,19 +262,19 @@ struct LoopedElemToLoc<1, OffsetT, true> { LoopedElemToLoc(int dim) : dim(dim) {} - void next(const constant int* shape, const constant size_t* strides) { + void next(const constant int* shape, const constant int64_t* strides) { index++; if (dim > 1) { - offset = elem_to_loc(index, shape, strides, dim); + offset = elem_to_loc(index, shape, strides, dim); } else { offset += OffsetT(strides[0]); } } - void next(int n, const constant int* shape, const constant size_t* strides) { + void next(int n, const constant int* shape, const constant int64_t* strides) { index += n; if (dim > 1) { - offset = elem_to_loc(index, shape, strides, dim); + offset = elem_to_loc(index, shape, strides, dim); } else { offset = index * OffsetT(strides[0]); } @@ -291,11 +291,11 @@ struct LoopedElemToLoc<1, OffsetT, false> { LoopedElemToLoc(int) {} - void next(const constant int*, const constant size_t* strides) { + void next(const constant int*, const constant int64_t* strides) { offset += OffsetT(strides[0]); } - void next(int n, const constant int*, const constant size_t* strides) { + void next(int n, const constant int*, const constant int64_t* strides) { offset += n * OffsetT(strides[0]); } diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp index 7d2ccd87f..09ca27a4e 100644 --- a/mlx/backend/metal/matmul.cpp +++ b/mlx/backend/metal/matmul.cpp @@ -21,8 +21,8 @@ namespace { inline auto collapse_batches(const array& a, const array& b) { // Get and check the shape for the batched dims - std::vector A_bshape{a.shape().begin(), a.shape().end() - 2}; - std::vector B_bshape{b.shape().begin(), b.shape().end() - 2}; + Shape A_bshape{a.shape().begin(), a.shape().end() - 2}; + Shape 
B_bshape{b.shape().begin(), b.shape().end() - 2}; if (A_bshape != B_bshape) { std::ostringstream msg; msg << "[matmul] Got matrices with incorrectly broadcasted shapes: " << "A " @@ -30,8 +30,8 @@ inline auto collapse_batches(const array& a, const array& b) { throw std::runtime_error(msg.str()); } - std::vector A_bstride{a.strides().begin(), a.strides().end() - 2}; - std::vector B_bstride{b.strides().begin(), b.strides().end() - 2}; + Strides A_bstride{a.strides().begin(), a.strides().end() - 2}; + Strides B_bstride{b.strides().begin(), b.strides().end() - 2}; auto [batch_shape, batch_strides] = collapse_contiguous_dims(A_bshape, std::vector{A_bstride, B_bstride}); @@ -50,9 +50,9 @@ inline auto collapse_batches(const array& a, const array& b) { inline auto collapse_batches(const array& a, const array& b, const array& c) { // Get and check the shape for the batched dims - std::vector A_bshape{a.shape().begin(), a.shape().end() - 2}; - std::vector B_bshape{b.shape().begin(), b.shape().end() - 2}; - std::vector C_bshape{c.shape().begin(), c.shape().end() - 2}; + Shape A_bshape{a.shape().begin(), a.shape().end() - 2}; + Shape B_bshape{b.shape().begin(), b.shape().end() - 2}; + Shape C_bshape{c.shape().begin(), c.shape().end() - 2}; if (A_bshape != B_bshape || A_bshape != C_bshape) { std::ostringstream msg; msg << "[addmm] Got matrices with incorrectly broadcasted shapes: " << "A " @@ -60,9 +60,9 @@ inline auto collapse_batches(const array& a, const array& b, const array& c) { throw std::runtime_error(msg.str()); } - std::vector A_bstride{a.strides().begin(), a.strides().end() - 2}; - std::vector B_bstride{b.strides().begin(), b.strides().end() - 2}; - std::vector C_bstride{c.strides().begin(), c.strides().end() - 2}; + Strides A_bstride{a.strides().begin(), a.strides().end() - 2}; + Strides B_bstride{b.strides().begin(), b.strides().end() - 2}; + Strides C_bstride{c.strides().begin(), c.strides().end() - 2}; auto [batch_shape, batch_strides] = collapse_contiguous_dims( A_bshape, std::vector{A_bstride, B_bstride, C_bstride}); @@ -82,6 +82,25 @@ inline auto collapse_batches(const array& a, const array& b, const array& c) { batch_shape, A_batch_stride, B_batch_stride, C_batch_stride); } +std::tuple check_transpose( + std::vector& copies, + const Stream& s, + const array& arr, + bool is_vector) { + auto stx = arr.strides()[arr.ndim() - 2]; + auto sty = arr.strides()[arr.ndim() - 1]; + if (sty == 1 && (!is_vector || stx == arr.shape(-1))) { + return std::make_tuple(false, stx, arr); + } else if (stx == 1 && (!is_vector || sty == arr.shape(-2))) { + return std::make_tuple(true, sty, arr); + } else { + array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); + copy_gpu(arr, arr_copy, CopyType::General, s); + copies.push_back(arr_copy); + return std::make_tuple(false, arr.shape(-1), arr_copy); + } +}; + } // namespace /////////////////////////////////////////////////////////////////////////////// @@ -180,11 +199,11 @@ void steel_matmul_regular( int ldd, bool transpose_a, bool transpose_b, - std::vector batch_shape, - std::vector batch_strides, - size_t A_batch_stride, - size_t B_batch_stride, - size_t matrix_stride_out, + Shape batch_shape, + Strides batch_strides, + int64_t A_batch_stride, + int64_t B_batch_stride, + int64_t matrix_stride_out, std::vector& copies) { using namespace mlx::steel; @@ -268,9 +287,9 @@ void steel_matmul_regular( /* const int ldd = */ ldd, /* const int tiles_n = */ tn, /* const int tiles_m = */ tm, - /* const size_t batch_stride_a = */ A_batch_stride, - /* const size_t 
batch_stride_b = */ B_batch_stride, - /* const size_t batch_stride_d = */ matrix_stride_out, + /* const int64_t batch_stride_a = */ A_batch_stride, + /* const int64_t batch_stride_b = */ B_batch_stride, + /* const int64_t batch_stride_d = */ matrix_stride_out, /* const int swizzle_log = */ swizzle_log, /* const int gemm_k_iterations_aligned = */ (K / bk), /* const int batch_ndim = */ int(batch_shape.size())}; @@ -314,9 +333,9 @@ void steel_matmul( bool transpose_a, bool transpose_b, std::vector& copies, - std::vector batch_shape /* = {} */, - std::vector A_batch_stride /* = {} */, - std::vector B_batch_stride /* = {} */) { + Shape batch_shape /* = {} */, + Strides A_batch_stride /* = {} */, + Strides B_batch_stride /* = {} */) { using namespace mlx::steel; if (batch_shape.empty()) { @@ -447,7 +466,7 @@ void steel_matmul( ///////////////////////////////////////////////////////////////////////////// // Regular kernel dispatch - std::vector batch_strides = A_batch_stride; + auto batch_strides = A_batch_stride; batch_strides.insert( batch_strides.end(), B_batch_stride.begin(), B_batch_stride.end()); @@ -505,24 +524,8 @@ void Matmul::eval_gpu(const std::vector& inputs, array& out) { // Keep a vector with copies to be cleared in the completed buffer to release // the arrays std::vector copies; - auto check_transpose = [&copies, &s](const array& arr, bool is_vector) { - auto stx = arr.strides()[arr.ndim() - 2]; - auto sty = arr.strides()[arr.ndim() - 1]; - if (sty == 1 && (!is_vector || stx == arr.shape(-1))) { - return std::make_tuple(false, stx, arr); - } else if (stx == 1 && (!is_vector || sty == arr.shape(-2))) { - return std::make_tuple(true, sty, arr); - } else { - array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); - copy_gpu(arr, arr_copy, CopyType::General, s); - copies.push_back(arr_copy); - size_t stx = arr.shape(-1); - return std::make_tuple(false, stx, arr_copy); - } - }; - - auto [a_transposed, a_cols, a] = check_transpose(a_pre, M == 1); - auto [b_transposed, b_cols, b] = check_transpose(b_pre, N == 1); + auto [a_transposed, a_cols, a] = check_transpose(copies, s, a_pre, M == 1); + auto [b_transposed, b_cols, b] = check_transpose(copies, s, b_pre, N == 1); ///////////////////////////////////////////////////////////////////////////// // Check and collapse batch dimensions @@ -662,9 +665,9 @@ void Matmul::eval_gpu(const std::vector& inputs, array& out) { /* bool transpose_a = */ a_transposed, /* bool transpose_b = */ b_transposed, /* std::vector& = */ copies, - /* std::vector batch_shape = */ batch_shape, - /* std::vector A_batch_stride = */ A_batch_stride, - /* std::vector B_batch_stride = */ B_batch_stride); + /* Shape batch_shape = */ batch_shape, + /* Strides A_batch_stride = */ A_batch_stride, + /* Strides B_batch_stride = */ B_batch_stride); } void AddMM::eval_gpu(const std::vector& inputs, array& out) { @@ -691,24 +694,8 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) { // Keep a vector with copies to be cleared in the completed buffer to release // the arrays std::vector copies; - auto check_transpose = [&copies, &s](const array& arr, bool is_vector) { - auto stx = arr.strides()[arr.ndim() - 2]; - auto sty = arr.strides()[arr.ndim() - 1]; - if (sty == 1 && (!is_vector || stx == arr.shape(-1))) { - return std::make_tuple(false, stx, arr); - } else if (stx == 1 && (!is_vector || sty == arr.shape(-2))) { - return std::make_tuple(true, sty, arr); - } else { - array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); - copy_gpu(arr, arr_copy, CopyType::General, s); 
- copies.push_back(arr_copy); - size_t stx = arr.shape(-1); - return std::make_tuple(false, stx, arr_copy); - } - }; - - auto [transpose_a, a_cols, a] = check_transpose(a_pre, M == 1); - auto [transpose_b, b_cols, b] = check_transpose(b_pre, N == 1); + auto [transpose_a, a_cols, a] = check_transpose(copies, s, a_pre, M == 1); + auto [transpose_b, b_cols, b] = check_transpose(copies, s, b_pre, N == 1); array c = c_pre; int ldc = c.strides()[c.ndim() - 2]; @@ -723,7 +710,7 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) { auto [batch_shape, A_batch_stride, B_batch_stride, C_batch_stride] = collapse_batches(a, b, c); - size_t matrix_stride_out = size_t(M) * size_t(N); + int64_t matrix_stride_out = M * static_cast(N); auto batch_size_out = out.size() / (matrix_stride_out); // Collapse batches into M if needed @@ -1044,9 +1031,9 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) { /* const int ldd = */ N, /* const int tiles_n = */ tn, /* const int tiles_m = */ tm, - /* const size_t batch_stride_a = */ A_batch_stride.back(), - /* const size_t batch_stride_b = */ B_batch_stride.back(), - /* const size_t batch_stride_d = */ matrix_stride_out, + /* const int64_t batch_stride_a = */ A_batch_stride.back(), + /* const int64_t batch_stride_b = */ B_batch_stride.back(), + /* const int64_t batch_stride_d = */ matrix_stride_out, /* const int swizzle_log = */ swizzle_log, /* const int gemm_k_iterations_aligned = */ (K / bk), /* const int batch_ndim = */ int(batch_shape.size())}; @@ -1054,7 +1041,7 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) { GEMMAddMMParams params{ /* const int ldc = */ ldc, /* const int fdc = */ fdc, - /* const size_t batch_stride_c = */ C_batch_stride.back(), + /* const int64_t batch_stride_c = */ C_batch_stride.back(), /* const float alpha = */ alpha_, /* const float beta = */ beta_}; @@ -1065,7 +1052,7 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) { MTL::Size group_dims = MTL::Size(32, wn, wm); MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out); - std::vector batch_strides = A_batch_stride; + Strides batch_strides = A_batch_stride; batch_strides.insert( batch_strides.end(), B_batch_stride.begin(), B_batch_stride.end()); batch_strides.insert( @@ -1120,24 +1107,8 @@ void BlockMaskedMM::eval_gpu(const std::vector& inputs, array& out) { // Keep a vector with copies to be cleared in the completed buffer to release // the arrays std::vector copies; - auto check_transpose = [&copies, &s](const array& arr, bool is_vector) { - auto stx = arr.strides()[arr.ndim() - 2]; - auto sty = arr.strides()[arr.ndim() - 1]; - if (sty == 1 && (!is_vector || stx == arr.shape(-1))) { - return std::make_tuple(false, stx, arr); - } else if (stx == 1 && (!is_vector || sty == arr.shape(-2))) { - return std::make_tuple(true, sty, arr); - } else { - array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); - copy_gpu(arr, arr_copy, CopyType::General, s); - copies.push_back(arr_copy); - size_t stx = arr.shape(-1); - return std::make_tuple(false, stx, arr_copy); - } - }; - - auto [transpose_a, a_cols, a] = check_transpose(a_pre, M == 1); - auto [transpose_b, b_cols, b] = check_transpose(b_pre, N == 1); + auto [transpose_a, a_cols, a] = check_transpose(copies, s, a_pre, M == 1); + auto [transpose_b, b_cols, b] = check_transpose(copies, s, b_pre, N == 1); int lda = a_cols; int ldb = b_cols; @@ -1156,20 +1127,20 @@ void BlockMaskedMM::eval_gpu(const std::vector& inputs, array& out) { return decltype(v){v.begin(), v.end() - 2}; }; - std::vector 
batch_shape{1}; - std::vector A_batch_stride{0}; - std::vector B_batch_stride{0}; - std::vector outmask_bstride{0}; - std::vector Amask_bstride{0}; - std::vector Bmask_bstride{0}; - size_t A_batch_str = 0; - size_t B_batch_str = 0; + Shape batch_shape{1}; + Strides A_batch_stride{0}; + Strides B_batch_stride{0}; + Strides outmask_bstride{0}; + Strides Amask_bstride{0}; + Strides Bmask_bstride{0}; + int64_t A_batch_str = 0; + int64_t B_batch_str = 0; - std::vector batch_strides; + Strides batch_strides; if (out.ndim() > 2) { - std::vector bshape{out.shape().begin(), out.shape().end() - 2}; - std::vector> bstrides; + Shape bshape{out.shape().begin(), out.shape().end() - 2}; + std::vector bstrides; for (auto& arr : inputs) { bstrides.emplace_back(arr.strides().begin(), arr.strides().end() - 2); @@ -1196,10 +1167,10 @@ void BlockMaskedMM::eval_gpu(const std::vector& inputs, array& out) { } } else { - batch_strides = std::vector(inputs.size(), 0); + batch_strides = Strides(inputs.size(), 0); } - size_t matrix_stride_out = size_t(M) * N; + int64_t matrix_stride_out = static_cast(M) * N; size_t batch_size_out = out.size() / (matrix_stride_out); ///////////////////////////////////////////////////////////////////////////// @@ -1306,7 +1277,7 @@ void BlockMaskedMM::eval_gpu(const std::vector& inputs, array& out) { // Get mask params std::vector mask_strides; - std::vector mask_batch_strides; + Strides mask_batch_strides; if (has_out_mask) { auto& out_mask = inputs[2]; @@ -1436,9 +1407,9 @@ void BlockMaskedMM::eval_gpu(const std::vector& inputs, array& out) { /* const int ldd = */ N, /* const int tiles_n = */ tn, /* const int tiles_m = */ tm, - /* const size_t batch_stride_a = */ A_batch_str, - /* const size_t batch_stride_b = */ B_batch_str, - /* const size_t batch_stride_d = */ matrix_stride_out, + /* const int64_t batch_stride_a = */ A_batch_str, + /* const int64_t batch_stride_b = */ B_batch_str, + /* const int64_t batch_stride_d = */ matrix_stride_out, /* const int swizzle_log = */ swizzle_log, /* const int gemm_k_iterations_aligned = */ (K / bk), /* const int batch_ndim = */ int(batch_shape.size())}; @@ -1524,24 +1495,8 @@ void GatherMM::eval_gpu(const std::vector& inputs, array& out) { // Keep a vector with copies to be cleared in the completed buffer to release // the arrays std::vector copies; - auto check_transpose = [&copies, &s](const array& arr, bool is_vector) { - auto stx = arr.strides()[arr.ndim() - 2]; - auto sty = arr.strides()[arr.ndim() - 1]; - if (sty == 1 && (!is_vector || stx == arr.shape(-1))) { - return std::make_tuple(false, stx, arr); - } else if (stx == 1 && (!is_vector || sty == arr.shape(-2))) { - return std::make_tuple(true, sty, arr); - } else { - array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); - copy_gpu(arr, arr_copy, CopyType::General, s); - copies.push_back(arr_copy); - size_t stx = arr.shape(-1); - return std::make_tuple(false, stx, arr_copy); - } - }; - - auto [transpose_a, a_cols, a] = check_transpose(a_pre, M == 1); - auto [transpose_b, b_cols, b] = check_transpose(b_pre, N == 1); + auto [transpose_a, a_cols, a] = check_transpose(copies, s, a_pre, M == 1); + auto [transpose_b, b_cols, b] = check_transpose(copies, s, b_pre, N == 1); int lda = a_cols; int ldb = b_cols; @@ -1556,20 +1511,20 @@ void GatherMM::eval_gpu(const std::vector& inputs, array& out) { auto& lhs_indices = inputs[2]; auto& rhs_indices = inputs[3]; - std::vector batch_shape = get_batch_dims(out.shape()); - std::vector batch_strides; + Shape batch_shape = get_batch_dims(out.shape()); + 
Strides batch_strides; batch_strides.insert( batch_strides.end(), lhs_indices.strides().begin(), lhs_indices.strides().end()); - size_t lhs_indices_str = batch_strides.empty() ? 0 : batch_strides.back(); + auto lhs_indices_str = batch_strides.empty() ? 0 : batch_strides.back(); batch_strides.insert( batch_strides.end(), rhs_indices.strides().begin(), rhs_indices.strides().end()); - size_t rhs_indices_str = batch_strides.empty() ? 0 : batch_strides.back(); + auto rhs_indices_str = batch_strides.empty() ? 0 : batch_strides.back(); int batch_ndim = batch_shape.size(); @@ -1582,10 +1537,10 @@ void GatherMM::eval_gpu(const std::vector& inputs, array& out) { int batch_ndim_B = b.ndim() - 2; std::vector operand_batch_ndim = {batch_ndim_A, batch_ndim_B}; - std::vector batch_shape_A = get_batch_dims(a.shape()); - std::vector batch_strides_A = get_batch_dims(a.strides()); - std::vector batch_shape_B = get_batch_dims(b.shape()); - std::vector batch_strides_B = get_batch_dims(b.strides()); + Shape batch_shape_A = get_batch_dims(a.shape()); + Strides batch_strides_A = get_batch_dims(a.strides()); + Shape batch_shape_B = get_batch_dims(b.shape()); + Strides batch_strides_B = get_batch_dims(b.strides()); if (batch_ndim_A == 0) { batch_shape_A = {1}; @@ -1597,7 +1552,7 @@ void GatherMM::eval_gpu(const std::vector& inputs, array& out) { batch_strides_B = {0}; } - size_t matrix_stride_out = size_t(M) * N; + auto matrix_stride_out = static_cast(M) * N; auto batch_size_out = out.size() / matrix_stride_out; ///////////////////////////////////////////////////////////////////////////// @@ -1801,9 +1756,9 @@ void GatherMM::eval_gpu(const std::vector& inputs, array& out) { /* const int ldd = */ N, /* const int tiles_n = */ tn, /* const int tiles_m = */ tm, - /* const size_t batch_stride_a = */ lhs_indices_str, - /* const size_t batch_stride_b = */ rhs_indices_str, - /* const size_t batch_stride_d = */ matrix_stride_out, + /* const int64_t batch_stride_a = */ lhs_indices_str, + /* const int64_t batch_stride_b = */ rhs_indices_str, + /* const int64_t batch_stride_d = */ matrix_stride_out, /* const int swizzle_log = */ swizzle_log, /* const int gemm_k_iterations_aligned = */ (K / bk), /* const int batch_ndim = */ batch_ndim}; diff --git a/mlx/backend/metal/matmul.h b/mlx/backend/metal/matmul.h index b6160d6f1..09ffe05a8 100644 --- a/mlx/backend/metal/matmul.h +++ b/mlx/backend/metal/matmul.h @@ -21,11 +21,11 @@ void steel_matmul_regular( int ldd, bool transpose_a, bool transpose_b, - std::vector batch_shape, - std::vector batch_strides, - size_t A_batch_stride, - size_t B_batch_stride, - size_t matrix_stride_out, + Shape batch_shape, + Strides batch_strides, + int64_t A_batch_stride, + int64_t B_batch_stride, + int64_t matrix_stride_out, std::vector& copies); void steel_matmul( @@ -43,8 +43,8 @@ void steel_matmul( bool transpose_a, bool transpose_b, std::vector& copies, - std::vector batch_shape = {}, - std::vector A_batch_stride = {}, - std::vector B_batch_stride = {}); + Shape batch_shape = {}, + Strides A_batch_stride = {}, + Strides B_batch_stride = {}); } // namespace mlx::core diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp index 9d8d3f942..5c86d2b84 100644 --- a/mlx/backend/metal/primitives.cpp +++ b/mlx/backend/metal/primitives.cpp @@ -5,6 +5,7 @@ #include #include "mlx/backend/common/load.h" +#include "mlx/backend/common/slicing.h" #include "mlx/backend/common/utils.h" #include "mlx/backend/metal/copy.h" #include "mlx/backend/metal/device.h" @@ -101,10 +102,10 @@ void 
ArgReduce::eval_gpu(const std::vector& inputs, array& out) { } // Prepare the shapes, strides and axis arguments. - std::vector in_strides = in.strides(); - std::vector shape = in.shape(); - std::vector out_strides = out.strides(); - size_t axis_stride = in_strides[axis_]; + auto in_strides = in.strides(); + auto shape = in.shape(); + auto out_strides = out.strides(); + auto axis_stride = in_strides[axis_]; size_t axis_size = shape[axis_]; if (out_strides.size() == in_strides.size()) { out_strides.erase(out_strides.begin() + axis_); @@ -136,7 +137,7 @@ void ArgReduce::eval_gpu(const std::vector& inputs, array& out) { if (ndim == 0) { // Pass place holders so metal doesn't complain int shape_ = 0; - size_t stride_ = 0; + int64_t stride_ = 0; compute_encoder.set_bytes(shape_, 2); compute_encoder.set_bytes(stride_, 3); compute_encoder.set_bytes(stride_, 4); @@ -311,13 +312,12 @@ void Reshape::eval_gpu(const std::vector& inputs, array& out) { if (copy_necessary) { out.set_data(allocator::malloc_or_wait(out.nbytes())); - auto out_strides = make_contiguous_strides(in.shape()); copy_gpu_inplace( in, out, in.shape(), in.strides(), - out_strides, + make_contiguous_strides(in.shape()), 0, 0, CopyType::General, @@ -366,16 +366,15 @@ void SliceUpdate::eval_gpu(const std::vector& inputs, array& out) { copy_gpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream()); // Calculate out strides, initial offset and if copy needs to be made - auto [data_offset, out_strides] = prepare_slice(out); + auto [data_offset, out_strides] = prepare_slice(in, start_indices_, strides_); // Do copy - std::vector upd_strides{upd.strides().begin(), upd.strides().end()}; - copy_gpu_inplace( + copy_gpu_inplace( /* const array& src = */ upd, /* array& dst = */ out, - /* const std::vector& data_shape = */ upd.shape(), - /* const std::vector& i_strides = */ upd_strides, - /* const std::vector& o_strides = */ out_strides, + /* const Shape& data_shape = */ upd.shape(), + /* const Strides& i_strides = */ upd.strides(), + /* const Strides& o_strides = */ out_strides, /* int64_t i_offset = */ 0, /* int64_t o_offset = */ data_offset, /* CopyType ctype = */ CopyType::GeneralGeneral, diff --git a/mlx/backend/metal/reduce.cpp b/mlx/backend/metal/reduce.cpp index 15cbcc9af..6720051d5 100644 --- a/mlx/backend/metal/reduce.cpp +++ b/mlx/backend/metal/reduce.cpp @@ -18,13 +18,13 @@ namespace { struct RowReduceArgs { // Input shape and strides not including the reduction axes - std::vector shape; - std::vector strides; + Shape shape; + Strides strides; int ndim; // Input shape and strides for the reduction axes - std::vector reduce_shape; - std::vector reduce_strides; + Shape reduce_shape; + Strides reduce_strides; int reduce_ndim; // The number of rows we are reducing. Namely prod(reduce_shape). @@ -88,13 +88,13 @@ struct RowReduceArgs { struct ColReduceArgs { // Input shape and strides not including the reduction axes - std::vector shape; - std::vector strides; + Shape shape; + Strides strides; int ndim; // Input shape and strides for the reduction axes - std::vector reduce_shape; - std::vector reduce_strides; + Shape reduce_shape; + Strides reduce_strides; int reduce_ndim; // The number of column reductions we are doing. Namely prod(reduce_shape). @@ -102,7 +102,7 @@ struct ColReduceArgs { // The size of the contiguous column reduction. 
size_t reduction_size; - size_t reduction_stride; + int64_t reduction_stride; ColReduceArgs( const array& in, @@ -126,7 +126,7 @@ struct ColReduceArgs { // yet we may have removed the appropriate amount of elements. It is safe // to compute the stride by multiplying shapes (while < reduction_stride) // because it is a contiguous section. - size_t stride_back = 1; + int64_t stride_back = 1; std::tie(shape, strides) = shapes_without_reduction_axes(in, axes); while (!shape.empty() && stride_back < reduction_stride) { stride_back *= shape.back(); @@ -683,7 +683,7 @@ void strided_reduce_longcolumn( op_name, in_type, out_type, - large ? "size_t" : "uint", + large ? "int64_t" : "uint", n); compute_encoder.set_compute_pipeline_state(kernel); @@ -718,7 +718,7 @@ void strided_reduce_longcolumn( op_name, intermediate.dtype(), out_type, - large ? "size_t" : "uint", + large ? "int64_t" : "uint", 1, 32, 32); @@ -782,7 +782,7 @@ void strided_reduce_looped( op_name, in_type, out_type, - large ? "size_t" : "uint", + large ? "int64_t" : "uint", n, BM, BN); @@ -859,7 +859,7 @@ void strided_reduce_2pass( op_name, in_type, out_type, - large ? "size_t" : "uint", + large ? "int64_t" : "uint", n, BM, BN); @@ -894,7 +894,7 @@ void strided_reduce_2pass( op_name, intermediate.dtype(), out_type, - large ? "size_t" : "uint", + large ? "int64_t" : "uint", 1, 32, 32); diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp index f600a4890..db5abbf90 100644 --- a/mlx/backend/metal/scaled_dot_product_attention.cpp +++ b/mlx/backend/metal/scaled_dot_product_attention.cpp @@ -50,17 +50,17 @@ void sdpa_full_self_attention_metal( std::ostringstream kname; // clang-format off - kname << "steel_attention_" - << type_to_name(q) - << "_bq" << bq + kname << "steel_attention_" + << type_to_name(q) + << "_bq" << bq << "_bk" << bk - << "_bd" << bd + << "_bd" << bd << "_wm" << wm << "_wn" << wn; // clang-format on std::string base_name = kname.str(); // clang-format off - kname << "_align_Q_" << (align_Q ? 't' : 'n') + kname << "_align_Q_" << (align_Q ? 't' : 'n') << "_align_K_" << (align_K ? 
't' : 'n'); // clang-format on std::string hash_name = kname.str(); @@ -92,10 +92,10 @@ void sdpa_full_self_attention_metal( /* int NQ_aligned = */ NQ_aligned, /* int NK_aligned = */ NK_aligned, - /* size_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)}, - /* size_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)}, - /* size_t V_strides[3] = */ {v.strides(0), v.strides(1), v.strides(2)}, - /* size_t O_strides[3] = */ {o.strides(0), o.strides(1), o.strides(2)}}; + /* int64_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)}, + /* int64_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)}, + /* int64_t V_strides[3] = */ {v.strides(0), v.strides(1), v.strides(2)}, + /* int64_t O_strides[3] = */ {o.strides(0), o.strides(1), o.strides(2)}}; compute_encoder.set_input_array(q, 0); compute_encoder.set_input_array(k, 1); @@ -175,13 +175,13 @@ void sdpa_vector_2pass( int N = k.shape(2); int blocks = 32; int B = q.shape(0) * q.shape(1); - size_t k_stride = k.strides()[1]; - size_t v_stride = v.strides()[1]; + auto k_stride = k.strides()[1]; + auto v_stride = v.strides()[1]; MTL::Size group_dims(8 * 32, 1, 1); MTL::Size grid_dims(1, B, blocks); // Allocate the intermediates - std::vector intermediate_shape; + Shape intermediate_shape; intermediate_shape.reserve(out.ndim() + 1); intermediate_shape.insert( intermediate_shape.end(), out.shape().begin(), out.shape().end() - 1); @@ -324,10 +324,10 @@ void ScaledDotProductAttention::eval_gpu( const auto& k = copy_unless(is_matrix_contiguous, k_pre); const auto& v = copy_unless(is_matrix_contiguous, v_pre); - size_t str_oD = 1; - size_t str_oH = o.shape(3); - size_t str_oL = o.shape(1) * str_oH; - size_t str_oB = o.shape(2) * str_oL; + int64_t str_oD = 1; + int64_t str_oH = o.shape(3); + int64_t str_oL = o.shape(1) * str_oH; + int64_t str_oB = o.shape(2) * str_oL; size_t data_size = o.shape(0) * str_oB; array::Flags flags{ diff --git a/mlx/backend/metal/slicing.cpp b/mlx/backend/metal/slicing.cpp index a95337fd5..f2241f607 100644 --- a/mlx/backend/metal/slicing.cpp +++ b/mlx/backend/metal/slicing.cpp @@ -11,29 +11,28 @@ namespace mlx::core { void slice_gpu( const array& in, array& out, - const std::vector& start_indices, - const std::vector& strides, + const Shape& start_indices, + const Shape& strides, const Stream& s) { // Calculate out strides, initial offset and if copy needs to be made - auto [copy_needed, data_offset, inp_strides] = - prepare_slice(in, start_indices, strides); + auto [data_offset, inp_strides] = prepare_slice(in, start_indices, strides); + auto copy_needed = + std::any_of(strides.begin(), strides.end(), [](auto i) { return i < 0; }); // Do copy if needed if (copy_needed) { out.set_data(allocator::malloc_or_wait(out.nbytes())); - std::vector ostrides{out.strides().begin(), out.strides().end()}; copy_gpu_inplace( /* const array& in = */ in, /* array& out = */ out, /* const std::vector& data_shape = */ out.shape(), /* const std::vector& i_strides = */ inp_strides, - /* const std::vector& o_strides = */ ostrides, + /* const std::vector& o_strides = */ out.strides(), /* int64_t i_offset = */ data_offset, /* int64_t o_offset = */ 0, /* CopyType ctype = */ CopyType::General, /* const Stream& s = */ s); } else { - std::vector ostrides{inp_strides.begin(), inp_strides.end()}; size_t data_end = 1; for (int i = 0; i < strides.size(); ++i) { if (in.shape()[i] > 1) { @@ -42,7 +41,7 @@ void slice_gpu( } } size_t data_size = data_end - data_offset; - shared_buffer_slice(in, ostrides, data_offset, 
data_size, out); + shared_buffer_slice(in, inp_strides, data_offset, data_size, out); } } diff --git a/mlx/backend/metal/slicing.h b/mlx/backend/metal/slicing.h index 51da8b54c..5c62b7b73 100644 --- a/mlx/backend/metal/slicing.h +++ b/mlx/backend/metal/slicing.h @@ -9,8 +9,8 @@ namespace mlx::core { void slice_gpu( const array& in, array& out, - const std::vector& start_indices, - const std::vector& strides, + const Shape& start_indices, + const Shape& strides, const Stream& s); void concatenate_gpu( diff --git a/mlx/backend/metal/sort.cpp b/mlx/backend/metal/sort.cpp index d0d28e20c..91d074c6b 100644 --- a/mlx/backend/metal/sort.cpp +++ b/mlx/backend/metal/sort.cpp @@ -24,13 +24,13 @@ void single_block_sort( // Prepare shapes int n_rows = in.size() / in.shape(axis); - std::vector in_nc_str = in.strides(); + auto in_nc_str = in.strides(); in_nc_str.erase(in_nc_str.begin() + axis); - std::vector out_nc_str = out.strides(); + auto out_nc_str = out.strides(); out_nc_str.erase(out_nc_str.begin() + axis); - std::vector nc_shape = in.shape(); + auto nc_shape = in.shape(); nc_shape.erase(nc_shape.begin() + axis); int nc_dim = nc_shape.size(); @@ -106,10 +106,10 @@ void multi_block_sort( // Prepare shapes int n_rows = in.size() / in.shape(axis); - std::vector nc_str = in.strides(); + auto nc_str = in.strides(); nc_str.erase(nc_str.begin() + axis); - std::vector nc_shape = in.shape(); + auto nc_shape = in.shape(); nc_shape.erase(nc_shape.begin() + axis); int nc_dim = nc_shape.size(); diff --git a/mlx/backend/metal/ternary.cpp b/mlx/backend/metal/ternary.cpp index f81ea9240..d44a5151e 100644 --- a/mlx/backend/metal/ternary.cpp +++ b/mlx/backend/metal/ternary.cpp @@ -30,8 +30,8 @@ void ternary_op_gpu_inplace( return std::make_tuple( shape, strides[0], strides[1], strides[2], strides[3]); } else { - std::vector e; - return std::make_tuple(std::vector{}, e, e, e, e); + Strides e; + return std::make_tuple(Shape{}, e, e, e, e); } }; auto [shape, strides_a, strides_b, strides_c, strides_out] = maybe_collapse(); diff --git a/mlx/backend/metal/unary.cpp b/mlx/backend/metal/unary.cpp index d1d833dc7..ce7dc969b 100644 --- a/mlx/backend/metal/unary.cpp +++ b/mlx/backend/metal/unary.cpp @@ -30,7 +30,7 @@ void unary_op_gpu_inplace( if (!contig) { return collapse_contiguous_dims(in); } else { - return std::make_pair(std::vector{}, std::vector{}); + return std::make_pair(Shape{}, Strides{}); } }; auto [shape, strides] = maybe_collapse(); diff --git a/mlx/backend/metal/utils.cpp b/mlx/backend/metal/utils.cpp index 22beb5d43..2eaacba4c 100644 --- a/mlx/backend/metal/utils.cpp +++ b/mlx/backend/metal/utils.cpp @@ -87,9 +87,7 @@ MTL::Size get_block_dims(int dim0, int dim1, int dim2, int pow2 /* = 10 */) { return MTL::Size{1ul << pows[0], 1ul << pows[1], 1ul << pows[2]}; } -MTL::Size get_2d_grid_dims( - const std::vector& shape, - const std::vector& strides) { +MTL::Size get_2d_grid_dims(const Shape& shape, const Strides& strides) { // Dims with strides of 0 are ignored as they // correspond to broadcasted dimensions size_t grid_x = 1; @@ -114,10 +112,8 @@ MTL::Size get_2d_grid_dims( static_cast(grid_x), static_cast(grid_y), 1); } -MTL::Size get_2d_grid_dims( - const std::vector& shape, - const std::vector& strides, - size_t divisor) { +MTL::Size +get_2d_grid_dims(const Shape& shape, const Strides& strides, size_t divisor) { // Compute the 2d grid dimensions such that the total size of the grid is // divided by divisor. 
size_t grid_x = 1; diff --git a/mlx/backend/metal/utils.h b/mlx/backend/metal/utils.h index 082e4d116..cc56bab32 100644 --- a/mlx/backend/metal/utils.h +++ b/mlx/backend/metal/utils.h @@ -22,17 +22,13 @@ MTL::Size get_block_dims(int dim0, int dim1, int dim2, int pow2 = 10); // - overall size (product of non-broadcasted dimensions) is < UINT_MAX^2 // - shape and strides correspond to a contiguous (no holes) but // possibly broadcasted array -MTL::Size get_2d_grid_dims( - const std::vector& shape, - const std::vector& strides); +MTL::Size get_2d_grid_dims(const Shape& shape, const Strides& strides); // Same as above but we do an implicit division with divisor. // Basically, equivalent to factorizing // Prod(s \forall s in shape if strides[s] > 0) / divisor. -MTL::Size get_2d_grid_dims( - const std::vector& shape, - const std::vector& strides, - size_t divisor); +MTL::Size +get_2d_grid_dims(const Shape& shape, const Strides& strides, size_t divisor); inline NS::String* make_string(std::ostringstream& os) { std::string string = os.str(); diff --git a/mlx/einsum.cpp b/mlx/einsum.cpp index f809cc1e0..ce14b2315 100644 --- a/mlx/einsum.cpp +++ b/mlx/einsum.cpp @@ -381,7 +381,7 @@ array batch_tensordot( size2 *= x.shape(s); } - std::vector shape; + Shape shape; for (auto ax : i) { shape.push_back(x.shape(ax)); } @@ -391,7 +391,7 @@ array batch_tensordot( return reshape(transpose(x, reorder, s), std::move(shape), s); }; - std::vector out_shape; + Shape out_shape; for (auto ax : a_batch) { out_shape.push_back(a.shape(ax)); } @@ -455,7 +455,7 @@ array collapse_repeats(array in, Subscript& subscript, StreamOrDevice s) { axes.push_back(i); } } - std::vector idx_shape(n_expand--, 1); + Shape idx_shape(n_expand--, 1); idx_shape[0] = in.shape(axes.back()); auto idx = reshape(arange(in.shape(axes.back()), s), idx_shape, s); for (int i = 0; i < v; ++i) { diff --git a/mlx/fast.cpp b/mlx/fast.cpp index 0b5f96151..53262f1ad 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -1014,7 +1014,7 @@ std::string write_signature( } if (shape_infos[i].strides) { kernel_source += - (" const constant size_t* " + name + "_strides [[buffer(" + + (" const constant int64_t* " + name + "_strides [[buffer(" + std::to_string(index) + ")]],\n"); index++; } @@ -1144,7 +1144,7 @@ MetalKernelFunction metal_kernel( shape_infos = std::move(shape_infos), attributes = std::move(attributes)]( const std::vector& inputs, - const std::vector>& output_shapes, + const std::vector& output_shapes, const std::vector& output_dtypes, std::tuple grid, std::tuple threadgroup, diff --git a/mlx/fft.cpp b/mlx/fft.cpp index 67ba37c13..f0d41bf0f 100644 --- a/mlx/fft.cpp +++ b/mlx/fft.cpp @@ -12,7 +12,7 @@ namespace mlx::core::fft { array fft_impl( const array& a, - std::vector n, + Shape n, const std::vector& axes, bool real, bool inverse, @@ -59,7 +59,7 @@ array fft_impl( throw std::invalid_argument(msg.str()); } - std::vector in_shape = a.shape(); + auto in_shape = a.shape(); for (int i = 0; i < valid_axes.size(); ++i) { in_shape[valid_axes[i]] = n[i]; } @@ -76,13 +76,12 @@ array fft_impl( auto in = a; if (any_less) { - in = slice(in, std::vector(in.ndim(), 0), in_shape, s); + in = slice(in, Shape(in.ndim(), 0), in_shape, s); } if (any_greater) { // Pad with zeros auto tmp = zeros(in_shape, a.dtype(), s); - std::vector starts(in.ndim(), 0); - in = slice_update(tmp, in, starts, in.shape()); + in = slice_update(tmp, in, Shape(in.ndim(), 0), in.shape()); } auto out_shape = in_shape; @@ -106,7 +105,7 @@ array fft_impl( bool real, bool inverse, StreamOrDevice 
s) { - std::vector n; + Shape n; for (auto ax : axes) { n.push_back(a.shape(ax)); } @@ -124,7 +123,7 @@ array fft_impl(const array& a, bool real, bool inverse, StreamOrDevice s) { array fftn( const array& a, - const std::vector& n, + const Shape& n, const std::vector& axes, StreamOrDevice s /* = {} */) { return fft_impl(a, n, axes, false, false, s); @@ -141,7 +140,7 @@ array fftn(const array& a, StreamOrDevice s /* = {} */) { array ifftn( const array& a, - const std::vector& n, + const Shape& n, const std::vector& axes, StreamOrDevice s /* = {} */) { return fft_impl(a, n, axes, false, true, s); @@ -158,7 +157,7 @@ array ifftn(const array& a, StreamOrDevice s /* = {} */) { array rfftn( const array& a, - const std::vector& n, + const Shape& n, const std::vector& axes, StreamOrDevice s /* = {} */) { return fft_impl(a, n, axes, true, false, s); @@ -175,7 +174,7 @@ array rfftn(const array& a, StreamOrDevice s /* = {} */) { array irfftn( const array& a, - const std::vector& n, + const Shape& n, const std::vector& axes, StreamOrDevice s /* = {} */) { return fft_impl(a, n, axes, true, true, s); diff --git a/mlx/fft.h b/mlx/fft.h index 06298f821..2f02da73b 100644 --- a/mlx/fft.h +++ b/mlx/fft.h @@ -13,7 +13,7 @@ namespace mlx::core::fft { /** Compute the n-dimensional Fourier Transform. */ array fftn( const array& a, - const std::vector& n, + const Shape& n, const std::vector& axes, StreamOrDevice s = {}); array fftn(const array& a, const std::vector& axes, StreamOrDevice s = {}); @@ -22,7 +22,7 @@ array fftn(const array& a, StreamOrDevice s = {}); /** Compute the n-dimensional inverse Fourier Transform. */ array ifftn( const array& a, - const std::vector& n, + const Shape& n, const std::vector& axes, StreamOrDevice s = {}); array ifftn( @@ -50,7 +50,7 @@ inline array ifft(const array& a, int axis = -1, StreamOrDevice s = {}) { /** Compute the two-dimensional Fourier Transform. */ inline array fft2( const array& a, - const std::vector& n, + const Shape& n, const std::vector& axes, StreamOrDevice s = {}) { return fftn(a, n, axes, s); @@ -65,7 +65,7 @@ inline array fft2( /** Compute the two-dimensional inverse Fourier Transform. */ inline array ifft2( const array& a, - const std::vector& n, + const Shape& n, const std::vector& axes, StreamOrDevice s = {}) { return ifftn(a, n, axes, s); @@ -80,7 +80,7 @@ inline array ifft2( /** Compute the n-dimensional Fourier Transform on a real input. */ array rfftn( const array& a, - const std::vector& n, + const Shape& n, const std::vector& axes, StreamOrDevice s = {}); array rfftn( @@ -92,7 +92,7 @@ array rfftn(const array& a, StreamOrDevice s = {}); /** Compute the n-dimensional inverse of `rfftn`. */ array irfftn( const array& a, - const std::vector& n, + const Shape& n, const std::vector& axes, StreamOrDevice s = {}); array irfftn( @@ -119,7 +119,7 @@ inline array irfft(const array& a, int axis = -1, StreamOrDevice s = {}) { /** Compute the two-dimensional Fourier Transform on a real input. */ inline array rfft2( const array& a, - const std::vector& n, + const Shape& n, const std::vector& axes, StreamOrDevice s = {}) { return rfftn(a, n, axes, s); @@ -134,7 +134,7 @@ inline array rfft2( /** Compute the two-dimensional inverse of `rfft2`. 
*/ inline array irfft2( const array& a, - const std::vector& n, + const Shape& n, const std::vector& axes, StreamOrDevice s = {}) { return irfftn(a, n, axes, s); diff --git a/mlx/io/safetensors.cpp b/mlx/io/safetensors.cpp index 0e5d3f5a1..825e03304 100644 --- a/mlx/io/safetensors.cpp +++ b/mlx/io/safetensors.cpp @@ -138,7 +138,7 @@ SafetensorsLoad load_safetensors( continue; } const std::string& dtype = item.value().at("dtype"); - const std::vector& shape = item.value().at("shape"); + const Shape& shape = item.value().at("shape"); const std::vector& data_offsets = item.value().at("data_offsets"); Dtype type = dtype_from_safetensor_str(dtype); auto loaded_array = array( diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 577a17851..af772ce61 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -856,14 +856,7 @@ array concatenate( "[concatenate] No arrays provided for concatenation"); } - // Normalize the given axis - auto ax = axis < 0 ? axis + arrays[0].ndim() : axis; - if (ax < 0 || ax >= arrays[0].ndim()) { - std::ostringstream msg; - msg << "[concatenate] Invalid axis (" << axis << ") passed to concatenate" - << " for array with shape " << arrays[0].shape() << "."; - throw std::invalid_argument(msg.str()); - } + auto ax = normalize_axis_index(axis, arrays[0].ndim(), "[concatenate] "); auto throw_invalid_shapes = [&]() { std::ostringstream msg; @@ -925,12 +918,15 @@ array stack( int axis, StreamOrDevice s /* = {} */) { if (arrays.empty()) { - throw std::invalid_argument("No arrays provided for stacking"); + throw std::invalid_argument("[stack] No arrays provided for stacking"); } - if (!is_same_shape(arrays)) { - throw std::invalid_argument("All arrays must have the same shape"); + if (!std::all_of(arrays.begin(), arrays.end(), [&](const auto& a) { + return arrays[0].shape() == a.shape(); + })) { + throw std::invalid_argument("[stack] All arrays must have the same shape"); } - int normalized_axis = normalize_axis(axis, arrays[0].ndim() + 1); + auto normalized_axis = + normalize_axis_index(axis, arrays[0].ndim() + 1, "[stack] "); std::vector new_arrays; new_arrays.reserve(arrays.size()); for (auto& a : arrays) { @@ -945,7 +941,7 @@ array stack(const std::vector& arrays, StreamOrDevice s /* = {} */) { /** array repeat with axis */ array repeat(const array& arr, int repeats, int axis, StreamOrDevice s) { - axis = normalize_axis(axis, arr.ndim()); + axis = normalize_axis_index(axis, arr.ndim(), "[repeat] "); if (repeats < 0) { throw std::invalid_argument( diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 7acb4129c..ab1c1f03b 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -144,8 +144,7 @@ std::pair, std::vector> Primitive::vmap( throw std::invalid_argument(msg.str()); } -std::vector> Primitive::output_shapes( - const std::vector&) { +std::vector Primitive::output_shapes(const std::vector&) { std::ostringstream msg; msg << "[Primitive::output_shapes] "; this->print(msg); @@ -969,7 +968,7 @@ array conv_weight_backward_patches( } // padded strides (contiguous) - std::vector in_padded_strides(in.ndim(), 1); + Strides in_padded_strides(in.ndim(), 1); for (int i = in.ndim() - 2; i >= 0; --i) { in_padded_strides[i] = in_padded_strides[i + 1] * in_padded_shape[i + 1]; } @@ -984,14 +983,13 @@ array conv_weight_backward_patches( // patches are shaped as // (batch_dim, out_spatial_dims, weight_spatial_dims, in_channels) - std::vector patches_shape{ - cotan.shape().begin(), cotan.shape().end() - 1}; + Shape patches_shape{cotan.shape().begin(), cotan.shape().end() - 1}; patches_shape.insert( 
patches_shape.end(), wt.shape().begin() + 1, wt.shape().end()); // Resolve patch strides int n_spatial_dim = in.ndim() - 2; - std::vector patches_strides(patches_shape.size(), 1); + Strides patches_strides(patches_shape.size(), 1); patches_strides[0] = in_padded_strides[0]; for (int i = 1; i < n_spatial_dim + 1; i++) { patches_strides[i] = in_padded_strides[i] * kernel_strides[i - 1]; @@ -1095,8 +1093,8 @@ std::vector Convolution::vjp( // Handle negative padding if (has_neg_padding) { - std::vector starts(grad.ndim(), 0); - std::vector stops = grad.shape(); + Shape starts(grad.ndim(), 0); + auto stops = grad.shape(); for (int i = 0; i < grad.ndim() - 2; i++) { if (padding_lo[i] < 0) { diff --git a/mlx/primitives.h b/mlx/primitives.h index 0b0359a1b..a166f164c 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -1917,8 +1917,6 @@ class SliceUpdate : public UnaryPrimitive { std::vector strides_; void eval(const std::vector& inputs, array& out); - - std::tuple> prepare_slice(const array& in); }; class Softmax : public UnaryPrimitive { diff --git a/mlx/random.cpp b/mlx/random.cpp index ba6434191..a4755605c 100644 --- a/mlx/random.cpp +++ b/mlx/random.cpp @@ -34,7 +34,7 @@ array key(uint64_t seed) { } array bits( - const std::vector& shape, + const Shape& shape, int width /* 4 */, const std::optional& key_ /*= nullopt */, StreamOrDevice s /* = {} */) { @@ -45,7 +45,7 @@ array bits( << "."; throw std::invalid_argument(msg.str()); } - if (key.shape() != std::vector{2}) { + if (key.shape() != Shape{2}) { std::ostringstream msg; msg << "[bits] Expected key shape (2) but received " << key.shape() << "."; throw std::invalid_argument(msg.str()); @@ -118,7 +118,7 @@ array above_minus_one_with_default(Dtype dtype) { array uniform( const array& low, const array& high, - const std::vector& shape, + const Shape& shape, Dtype dtype /* = float32 */, const std::optional& key /*= nullopt */, StreamOrDevice s /* = {} */) { @@ -168,7 +168,7 @@ array uniform( } array uniform( - const std::vector& shape, + const Shape& shape, Dtype dtype, const std::optional& key /*= nullopt */, StreamOrDevice s /* = {} */) { @@ -177,7 +177,7 @@ array uniform( } array normal( - const std::vector& shape, + const Shape& shape, Dtype dtype, const float loc /* = 0.0 */, const float scale /* = 1.0 */, @@ -201,7 +201,7 @@ array normal( array multivariate_normal( const array& mean, const array& cov, - const std::vector& shape, + const Shape& shape, Dtype dtype, const std::optional& key /* = nullopt */, StreamOrDevice s) { @@ -234,12 +234,9 @@ array multivariate_normal( } // Compute output shape - std::vector truncated_output_shape; - auto truncated_mean_shape = - std::vector(mean.shape().begin(), mean.shape().end() - 1); - auto truncated_cov_shape = - std::vector(cov.shape().begin(), cov.shape().end() - 2); + Shape(mean.shape().begin(), mean.shape().end() - 1); + auto truncated_cov_shape = Shape(cov.shape().begin(), cov.shape().end() - 2); auto output_shape = broadcast_shapes(truncated_cov_shape, truncated_mean_shape); output_shape = broadcast_shapes(output_shape, shape); @@ -269,7 +266,7 @@ array multivariate_normal( array randint( const array& low, const array& high, - const std::vector& shape, + const Shape& shape, Dtype dtype /* = int32 */, const std::optional& key /*= nullopt */, StreamOrDevice s /* = {} */) { @@ -283,7 +280,7 @@ array randint( array bernoulli( const array& p, - const std::vector& shape, + const Shape& shape, const std::optional& key /*= nullopt */, StreamOrDevice s /* = {} */) { if (!issubdtype(p.dtype(), 
floating)) { @@ -322,7 +319,7 @@ array bernoulli( array truncated_normal( const array& lower, const array& upper, - const std::vector& shape, + const Shape& shape, Dtype dtype /* = float32 */, const std::optional& key /*= nullopt */, StreamOrDevice s /* = {} */) { @@ -357,7 +354,7 @@ array truncated_normal( } array gumbel( - const std::vector& shape, + const Shape& shape, Dtype dtype /* = float32 */, const std::optional& key /*= nullopt */, StreamOrDevice s /* = {} */) { @@ -380,7 +377,7 @@ int get_valid_axis(int axis, int ndim) { array categorical_impl( const array& logits, int axis, - const std::vector& shape, + const Shape& shape, const std::optional& key /*= nullopt */, StreamOrDevice s) { auto gumbel_shape = shape; @@ -393,7 +390,7 @@ array categorical_impl( array categorical( const array& logits, int axis, - const std::vector& shape, + const Shape& shape, const std::optional& key /*= nullopt */, StreamOrDevice s /* = {} */) { // Validate and normalize axis @@ -439,7 +436,7 @@ array categorical( } array laplace( - const std::vector& shape, + const Shape& shape, Dtype dtype, const float loc /* = 0.0 */, const float scale /* = 1.0 */, diff --git a/mlx/random.h b/mlx/random.h index d4d827230..b2c821736 100644 --- a/mlx/random.h +++ b/mlx/random.h @@ -42,12 +42,12 @@ void seed(uint64_t seed); /** Generate an array with type uint32 filled with random bits. */ array bits( - const std::vector& shape, + const Shape& shape, int width, const std::optional& key = std::nullopt, StreamOrDevice s = {}); inline array bits( - const std::vector& shape, + const Shape& shape, const std::optional& key = std::nullopt, StreamOrDevice s = {}) { return bits(shape, 4, key, s); @@ -63,7 +63,7 @@ array split(const array& key, int num, StreamOrDevice s = {}); array uniform( const array& low, const array& high, - const std::vector& shape, + const Shape& shape, Dtype dtype = float32, const std::optional& key = std::nullopt, StreamOrDevice s = {}); @@ -72,7 +72,7 @@ template array uniform( T low, U high, - const std::vector& shape, + const Shape& shape, Dtype dtype = float32, const std::optional& key = std::nullopt, StreamOrDevice s = {}) { @@ -81,12 +81,12 @@ array uniform( /** Generate uniform random numbers between 0 and 1. */ array uniform( - const std::vector& shape, + const Shape& shape, Dtype dtype, const std::optional& key = std::nullopt, StreamOrDevice s = {}); inline array uniform( - const std::vector& shape, + const Shape& shape, const std::optional& key = std::nullopt, StreamOrDevice s = {}) { return uniform(shape, float32, key); @@ -94,14 +94,14 @@ inline array uniform( /** Generate samples from the standard normal distribution. 
*/ array normal( - const std::vector& shape, + const Shape& shape, Dtype dtype, const float loc, const float scale, const std::optional& key = std::nullopt, StreamOrDevice s = {}); inline array normal( - const std::vector& shape, + const Shape& shape, const float loc, const float scale, const std::optional& key = std::nullopt, @@ -109,14 +109,14 @@ inline array normal( return normal(shape, float32, loc, scale, key, s); } inline array normal( - const std::vector& shape, + const Shape& shape, const Dtype dtype, const std::optional& key = std::nullopt, StreamOrDevice s = {}) { return normal(shape, dtype, 0.0, 1.0, key, s); } inline array normal( - const std::vector& shape, + const Shape& shape, const std::optional& key = std::nullopt, StreamOrDevice s = {}) { return normal(shape, float32, 0.0, 1.0, key, s); @@ -126,7 +126,7 @@ inline array normal( array multivariate_normal( const array& mean, const array& cov, - const std::vector& shape, + const Shape& shape, Dtype dtype, const std::optional& key = std::nullopt, StreamOrDevice s = {}); @@ -135,7 +135,7 @@ array multivariate_normal( array randint( const array& low, const array& high, - const std::vector& shape, + const Shape& shape, Dtype dtype = int32, const std::optional& key = std::nullopt, StreamOrDevice s = {}); @@ -144,7 +144,7 @@ template array randint( T low, U high, - const std::vector& shape, + const Shape& shape, Dtype dtype = int32, const std::optional& key = std::nullopt, StreamOrDevice s = {}) { @@ -154,7 +154,7 @@ array randint( /** Generate binary variables with probability to be true equal to p */ array bernoulli( const array& p, - const std::vector& shape, + const Shape& shape, const std::optional& key = std::nullopt, StreamOrDevice s = {}); array bernoulli( @@ -173,7 +173,7 @@ array bernoulli( template array bernoulli( T p, - const std::vector& shape, + const Shape& shape, const std::optional& key = std::nullopt, StreamOrDevice s = {}) { return bernoulli(array(p), shape, key, s); @@ -186,7 +186,7 @@ array bernoulli( array truncated_normal( const array& lower, const array& upper, - const std::vector& shape, + const Shape& shape, Dtype dtype = float32, const std::optional& key = std::nullopt, StreamOrDevice s = {}); @@ -199,7 +199,7 @@ array truncated_normal( StreamOrDevice s = {}); array gumbel( - const std::vector& shape, + const Shape& shape, Dtype dtype = float32, const std::optional& key = std::nullopt, StreamOrDevice s = {}); @@ -207,7 +207,7 @@ array gumbel( array categorical( const array& logits, int axis, - const std::vector& shape, + const Shape& shape, const std::optional& key = std::nullopt, StreamOrDevice s = {}); @@ -226,14 +226,14 @@ array categorical( /** Generate samples from the laplace distribution. 
*/ array laplace( - const std::vector& shape, + const Shape& shape, Dtype dtype, const float loc, const float scale, const std::optional& key = std::nullopt, StreamOrDevice s = {}); inline array laplace( - const std::vector& shape, + const Shape& shape, const float loc, const float scale, const std::optional& key = std::nullopt, @@ -241,14 +241,14 @@ inline array laplace( return laplace(shape, float32, loc, scale, key, s); } inline array laplace( - const std::vector& shape, + const Shape& shape, const Dtype dtype, const std::optional& key = std::nullopt, StreamOrDevice s = {}) { return laplace(shape, dtype, 0.0, 1.0, key, s); } inline array laplace( - const std::vector& shape, + const Shape& shape, const std::optional& key = std::nullopt, StreamOrDevice s = {}) { return laplace(shape, float32, 0.0, 1.0, key, s); diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp index 7a24b2f94..e1b94aa28 100644 --- a/mlx/transforms.cpp +++ b/mlx/transforms.cpp @@ -681,7 +681,7 @@ std::pair, std::vector> vmap_trace( std::vector s_inputs; for (int i = 0; i < inputs.size(); ++i) { if (in_axes[i] != -1) { - std::vector shape = inputs[i].shape(); + auto shape = inputs[i].shape(); shape.erase(shape.begin() + in_axes[i]); array in(shape, inputs[i].dtype(), nullptr, {}); s_inputs.push_back(in); @@ -924,7 +924,7 @@ std::function(const std::vector&)> custom_function( : default_stream(default_device()); // Make the output info - std::vector> shapes; + std::vector shapes; std::vector dtypes; for (const auto& out : outputs) { shapes.emplace_back(out.shape()); diff --git a/mlx/utils.cpp b/mlx/utils.cpp index 3956e2a24..daa90fea6 100644 --- a/mlx/utils.cpp +++ b/mlx/utils.cpp @@ -98,29 +98,17 @@ Shape broadcast_shapes(const Shape& s1, const Shape& s2) { return out_shape; } -bool is_same_shape(const std::vector& arrays) { - if (arrays.empty()) { - return true; - } - return std::all_of(arrays.begin() + 1, arrays.end(), [&](const array& a) { - return (a.shape() == arrays[0].shape()); - }); -} - -int normalize_axis(int axis, int ndim) { - if (ndim <= 0) { - throw std::invalid_argument("Number of dimensions must be positive."); - } +int normalize_axis_index( + int axis, + int ndim, + const std::string& msg_prefix /* = "" */) { if (axis < -ndim || axis >= ndim) { std::ostringstream msg; - msg << "Axis " << axis << " is out of bounds for array with " << ndim - << " dimensions."; + msg << msg_prefix << "Axis " << axis << " is out of bounds for array with " + << ndim << " dimensions."; throw std::invalid_argument(msg.str()); } - if (axis < 0) { - axis += ndim; - } - return axis; + return axis < 0 ? axis + ndim : axis; } std::ostream& operator<<(std::ostream& os, const Device& d) { @@ -323,15 +311,6 @@ std::ostream& operator<<(std::ostream& os, const Strides& v) { return os; } -std::ostream& operator<<(std::ostream& os, const std::vector& v) { - os << "("; - for (int i = 0; i < v.size(); ++i) { - os << v[i] << ((i == v.size() - 1) ? "" : ","); - } - os << ")"; - return os; -} - namespace env { int get_var(const char* name, int default_value) { diff --git a/mlx/utils.h b/mlx/utils.h index 5e0ef2222..108fdf203 100644 --- a/mlx/utils.h +++ b/mlx/utils.h @@ -64,30 +64,13 @@ Dtype result_type(const std::vector& arrays); Shape broadcast_shapes(const Shape& s1, const Shape& s2); -bool is_same_shape(const std::vector& arrays); - -/** Returns the shape dimension if it's within allowed range. 
*/ -template -int check_shape_dim(const T dim) { - constexpr bool is_signed = std::numeric_limits::is_signed; - using U = std::conditional_t; - constexpr U min = static_cast(std::numeric_limits::min()); - constexpr U max = static_cast(std::numeric_limits::max()); - - if ((is_signed && dim < min) || dim > max) { - throw std::invalid_argument( - "Shape dimension falls outside supported `int` range."); - } - - return static_cast(dim); -} - /** * Returns the axis normalized to be in the range [0, ndim). - * Based on numpy's normalize_axis_index. See - * https://numpy.org/devdocs/reference/generated/numpy.lib.array_utils.normalize_axis_index.html */ -int normalize_axis(int axis, int ndim); +int normalize_axis_index( + int axis, + int ndim, + const std::string& msg_prefix = ""); std::ostream& operator<<(std::ostream& os, const Device& d); std::ostream& operator<<(std::ostream& os, const Stream& s); @@ -96,7 +79,6 @@ std::ostream& operator<<(std::ostream& os, const Dtype::Kind& k); std::ostream& operator<<(std::ostream& os, array a); std::ostream& operator<<(std::ostream& os, const Shape& v); std::ostream& operator<<(std::ostream& os, const Strides& v); -std::ostream& operator<<(std::ostream& os, const std::vector& v); inline std::ostream& operator<<(std::ostream& os, const complex64_t& v) { return os << v.real() << (v.imag() >= 0 ? "+" : "") << v.imag() << "j"; } diff --git a/python/src/convert.cpp b/python/src/convert.cpp index 6cd874b65..46547ec0c 100644 --- a/python/src/convert.cpp +++ b/python/src/convert.cpp @@ -27,10 +27,18 @@ struct ndarray_traits { static constexpr dlpack::dtype bfloat16{4, 16, 1}; }; // namespace nanobind +int check_shape_dim(int64_t dim) { + if (dim > std::numeric_limits::max()) { + throw std::invalid_argument( + "Shape dimension falls outside supported `int` range."); + } + return static_cast(dim); +} + template array nd_array_to_mlx_contiguous( nb::ndarray nd_array, - const std::vector& shape, + const Shape& shape, Dtype dtype) { // Make a copy of the numpy buffer // Get buffer ptr pass to array constructor @@ -42,7 +50,7 @@ array nd_array_to_mlx( nb::ndarray nd_array, std::optional dtype) { // Compute the shape and size - std::vector shape; + Shape shape; for (int i = 0; i < nd_array.ndim(); i++) { shape.push_back(check_shape_dim(nd_array.shape(i))); } @@ -108,13 +116,12 @@ nb::ndarray mlx_to_nd_array_impl( a.eval(); } std::vector shape(a.shape().begin(), a.shape().end()); - std::vector strides(a.strides().begin(), a.strides().end()); return nb::ndarray( a.data(), a.ndim(), shape.data(), /* owner= */ nb::none(), - strides.data(), + a.strides().data(), t.value_or(nb::dtype())); } @@ -272,7 +279,7 @@ void fill_vector(T list, std::vector& vals) { template PyScalarT validate_shape( T list, - const std::vector& shape, + const Shape& shape, int idx, bool& all_python_primitive_elements) { if (idx >= shape.size()) { @@ -340,7 +347,7 @@ PyScalarT validate_shape( } template -void get_shape(T list, std::vector& shape) { +void get_shape(T list, Shape& shape) { shape.push_back(check_shape_dim(nb::len(list))); if (shape.back() > 0) { auto l = list.begin(); @@ -351,7 +358,7 @@ void get_shape(T list, std::vector& shape) { } else if (nb::isinstance(*l)) { auto arr = nb::cast(*l); for (int i = 0; i < arr.ndim(); i++) { - shape.push_back(check_shape_dim(arr.shape(i))); + shape.push_back(arr.shape(i)); } return; } @@ -363,7 +370,7 @@ array array_from_list_impl( T pl, const PyScalarT& inferred_type, std::optional specified_type, - const std::vector& shape) { + const Shape& shape) { // 
Make the array switch (inferred_type) { case pybool: { @@ -420,7 +427,7 @@ array array_from_list_impl( template array array_from_list_impl(T pl, std::optional dtype) { // Compute the shape - std::vector shape; + Shape shape; get_shape(pl, shape); // Validate the shape and type diff --git a/python/src/ops.cpp b/python/src/ops.cpp index a17c9ea0b..eb69b2659 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -2953,16 +2953,16 @@ void init_ops(nb::module_& m) { m.def( "as_strided", [](const array& a, - std::optional> shape, - std::optional> strides, + std::optional shape, + std::optional strides, size_t offset, StreamOrDevice s) { - std::vector a_shape = (shape) ? *shape : a.shape(); - std::vector a_strides; + auto a_shape = (shape) ? *shape : a.shape(); + Strides a_strides; if (strides) { a_strides = *strides; } else { - a_strides = std::vector(a_shape.size(), 1); + a_strides = Strides(a_shape.size(), 1); for (int i = a_shape.size() - 1; i > 0; i--) { a_strides[i - 1] = a_shape[i] * a_strides[i]; } diff --git a/tests/arg_reduce_tests.cpp b/tests/arg_reduce_tests.cpp index 55b966f78..124fd67ed 100644 --- a/tests/arg_reduce_tests.cpp +++ b/tests/arg_reduce_tests.cpp @@ -11,7 +11,7 @@ void test_arg_reduce_small( Device d, const array& x, ArgReduce::ReduceType r, - std::vector out_shape, + Shape out_shape, int axis, std::vector expected_output) { auto s = default_stream(d); @@ -27,7 +27,7 @@ void test_arg_reduce_small( void test_arg_reduce_against_cpu( const array& x, ArgReduce::ReduceType r, - std::vector out_shape, + Shape out_shape, int axis) { auto y1 = array( out_shape, @@ -125,7 +125,7 @@ TEST_CASE("test arg reduce against cpu") { void test_arg_reduce_small_bool( Device d, ArgReduce::ReduceType r, - std::vector out_shape, + Shape out_shape, int axis, std::vector expected_output) { auto s = default_stream(d); diff --git a/tests/array_tests.cpp b/tests/array_tests.cpp index f91c0a313..b31da9899 100644 --- a/tests/array_tests.cpp +++ b/tests/array_tests.cpp @@ -13,10 +13,10 @@ TEST_CASE("test array basics") { array x(1.0); CHECK_EQ(x.size(), 1); CHECK_EQ(x.ndim(), 0); - CHECK_EQ(x.shape(), std::vector{}); + CHECK_EQ(x.shape(), Shape{}); CHECK_THROWS_AS(x.shape(0), std::out_of_range); CHECK_THROWS_AS(x.shape(-1), std::out_of_range); - CHECK_EQ(x.strides(), std::vector{}); + CHECK_EQ(x.strides(), Strides{}); CHECK_EQ(x.itemsize(), sizeof(float)); CHECK_EQ(x.nbytes(), sizeof(float)); CHECK_EQ(x.dtype(), float32); @@ -39,12 +39,12 @@ TEST_CASE("test array basics") { CHECK_EQ(x.dtype(), float32); CHECK_EQ(x.size(), 1); CHECK_EQ(x.ndim(), 1); - CHECK_EQ(x.shape(), std::vector{1}); + CHECK_EQ(x.shape(), Shape{1}); CHECK_EQ(x.shape(0), 1); CHECK_EQ(x.shape(-1), 1); CHECK_THROWS_AS(x.shape(1), std::out_of_range); CHECK_THROWS_AS(x.shape(-2), std::out_of_range); - CHECK_EQ(x.strides(), std::vector{1}); + CHECK_EQ(x.strides(), Strides{1}); CHECK_EQ(x.item(), 1.0); // Check empty array @@ -57,7 +57,7 @@ TEST_CASE("test array basics") { x = array({1.0, 1.0}); CHECK_EQ(x.size(), 2); - CHECK_EQ(x.shape(), std::vector{2}); + CHECK_EQ(x.shape(), Shape{2}); CHECK_EQ(x.itemsize(), sizeof(float)); CHECK_EQ(x.nbytes(), x.itemsize() * x.size()); @@ -65,9 +65,9 @@ TEST_CASE("test array basics") { CHECK_THROWS_AS(x.item(), std::invalid_argument); x = array({1.0, 1.0, 1.0}, {1, 3}); - CHECK(x.size() == 3); - CHECK(x.shape() == std::vector{1, 3}); - CHECK(x.strides() == std::vector{3, 1}); + CHECK_EQ(x.size(), 3); + CHECK_EQ(x.shape(), Shape{1, 3}); + CHECK_EQ(x.strides(), Strides{3, 1}); // Test wrong 
size/shapes throw: CHECK_THROWS_AS(array({1.0, 1.0, 1.0}, {4}), std::invalid_argument); @@ -472,7 +472,7 @@ TEST_CASE("test array metadata") { x = array({1.0f, 2.0f, 3.0f}, {1, 3}); y = slice(x, {0, 0}, {1, 2}, {2, 3}); eval(y); - CHECK_EQ(y.shape(), std::vector{1, 1}); + CHECK_EQ(y.shape(), Shape{1, 1}); CHECK_EQ(y.data_size(), 1); CHECK_EQ(y.flags().contiguous, true); CHECK_EQ(y.flags().row_contiguous, true); @@ -481,7 +481,7 @@ TEST_CASE("test array metadata") { x = array({0.0f, 1.0f, 2.0f, 3.0f}, {1, 4}); y = slice(x, {0, 0}, {1, 4}, {1, 2}); eval(y); - CHECK_EQ(y.shape(), std::vector{1, 2}); + CHECK_EQ(y.shape(), Shape{1, 2}); CHECK_EQ(y.flags().contiguous, false); CHECK_EQ(y.flags().row_contiguous, false); CHECK_EQ(y.flags().col_contiguous, false); @@ -489,7 +489,7 @@ TEST_CASE("test array metadata") { x = broadcast_to(array(1.0f), {4, 10}); y = slice(x, {0, 0}, {4, 10}, {2, 2}); eval(y); - CHECK_EQ(y.shape(), std::vector{2, 5}); + CHECK_EQ(y.shape(), Shape{2, 5}); CHECK_EQ(y.data_size(), 1); CHECK_EQ(y.flags().contiguous, true); CHECK_EQ(y.flags().row_contiguous, false); @@ -566,8 +566,8 @@ TEST_CASE("test array iteration") { } TEST_CASE("test array shared buffer") { - std::vector shape = {2, 2}; - int n_elem = shape[0] * shape[1]; + Shape shape = {2, 2}; + auto n_elem = shape[0] * shape[1]; allocator::Buffer buf_b = allocator::malloc(n_elem * sizeof(float)); void* buf_b_ptr = buf_b.raw_ptr(); diff --git a/tests/autograd_tests.cpp b/tests/autograd_tests.cpp index 3fe839ca5..e5e9a270a 100644 --- a/tests/autograd_tests.cpp +++ b/tests/autograd_tests.cpp @@ -617,7 +617,7 @@ TEST_CASE("test op vjps") { axes = {0}; out = vjp(fun, array({}), array(3.0f)).second; CHECK_EQ(out.size(), 0); - CHECK_EQ(out.shape(), std::vector{0}); + CHECK_EQ(out.shape(), Shape{0}); axes = {0}; out = vjp(fun, ones({2, 2, 2}), array({1.0f, 2.0f, 3.0f, 4.0f}, {2, 2})) @@ -725,9 +725,9 @@ TEST_CASE("test gather and take grads") { } TEST_CASE("test slice grads") { - std::vector start = {5, 0, 0}; - std::vector stop = {7, 2, 4}; - std::vector strides = {1, 1, 1}; + Shape start = {5, 0, 0}; + Shape stop = {7, 2, 4}; + Shape strides = {1, 1, 1}; auto fn = [&start, &stop, &strides](array input) { return slice(input, start, stop, strides); @@ -982,8 +982,8 @@ TEST_CASE("test comparison grads") { TEST_CASE("test as_strided grads") { auto x = ones({11}); - std::vector shape = {5, 5}; - std::vector strides = {1, 1}; + Shape shape = {5, 5}; + Strides strides = {1, 1}; size_t offset = 0; auto fun = [&shape, &strides, &offset](array x) { diff --git a/tests/blas_tests.cpp b/tests/blas_tests.cpp index f700e8c8a..37653ba04 100644 --- a/tests/blas_tests.cpp +++ b/tests/blas_tests.cpp @@ -16,7 +16,7 @@ TEST_CASE("test matmul") { a = array({1.0}); b = array({1.0}); auto out = matmul(a, b); - CHECK_EQ(out.shape(), std::vector{}); + CHECK_EQ(out.shape(), Shape{}); CHECK_EQ(out.size(), 1); CHECK_EQ(out.dtype(), float32); CHECK_EQ(out.item(), 1.0f); diff --git a/tests/creations_tests.cpp b/tests/creations_tests.cpp index 528e9fc90..8f94fa3b8 100644 --- a/tests/creations_tests.cpp +++ b/tests/creations_tests.cpp @@ -208,14 +208,14 @@ TEST_CASE("test full") { // Check zeros and ones { auto x = zeros({2, 2}, float32); - CHECK_EQ(x.shape(), std::vector{2, 2}); + CHECK_EQ(x.shape(), Shape{2, 2}); CHECK_EQ(x.ndim(), 2); CHECK_EQ(x.dtype(), float32); auto y = array({0.0, 0.0, 0.0, 0.0}, {2, 2}); CHECK(array_equal(x, y).item()); x = ones({2, 2}, float32); - CHECK_EQ(x.shape(), std::vector{2, 2}); + CHECK_EQ(x.shape(), Shape{2, 2}); 
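A side note for readers following the test updates: the assertions now compare against brace-initialized Shape and Strides values instead of spelled-out std::vector types. Below is a self-contained sketch of what such aliases could look like and why a signed stride type matters; the element widths (32-bit dimensions, signed 64-bit strides) and the little loop are assumptions for illustration, not a copy of the real headers.

#include <cstdint>
#include <vector>

using Shape = std::vector<int32_t>;   // hypothetical stand-in: per-axis sizes
using Strides = std::vector<int64_t>; // hypothetical stand-in: signed per-axis strides

int main() {
  Shape shape = {1, 3};
  // Row-major strides for that shape, computed back to front: {3, 1},
  // matching the Strides{3, 1} expectation in the array tests.
  Strides strides(shape.size(), 1);
  for (int i = static_cast<int>(shape.size()) - 2; i >= 0; --i) {
    strides[i] = strides[i + 1] * shape[i + 1];
  }
  // Unlike size_t, a signed 64-bit stride can also describe a reversed axis
  // (e.g. {-3, 1}) without wrapping around.
  return (strides == Strides{3, 1}) ? 0 : 1;
}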
CHECK_EQ(x.ndim(), 2); CHECK_EQ(x.dtype(), float32); y = array({1.0, 1.0, 1.0, 1.0}, {2, 2}); @@ -235,11 +235,11 @@ TEST_CASE("test full") { // Works for empty shape and empty array { array x = ones({}, int32); - CHECK_EQ(x.shape(), std::vector{}); + CHECK_EQ(x.shape(), Shape{}); CHECK_EQ(x.item(), 1); x = full({0}, array({})); - CHECK_EQ(x.shape(), std::vector{0}); + CHECK_EQ(x.shape(), Shape{0}); CHECK_EQ(x.size(), 0); CHECK_THROWS_AS(full({}, array({})), std::invalid_argument); diff --git a/tests/fft_tests.cpp b/tests/fft_tests.cpp index 84d2f20c4..c04dda1d5 100644 --- a/tests/fft_tests.cpp +++ b/tests/fft_tests.cpp @@ -162,35 +162,35 @@ TEST_CASE("test fftn") { x = reshape(arange(20, float32), {5, 4}); y = fft::rfftn(x); - CHECK_EQ(y.shape(), std::vector{5, 3}); + CHECK_EQ(y.shape(), Shape{5, 3}); y = fft::rfftn(x, {1, 0}); - CHECK_EQ(y.shape(), std::vector{3, 4}); + CHECK_EQ(y.shape(), Shape{3, 4}); x = reshape(arange(20, float32), {5, 4}); y = fft::irfftn(x); - CHECK_EQ(y.shape(), std::vector{5, 6}); + CHECK_EQ(y.shape(), Shape{5, 6}); y = fft::irfftn(x, {1, 0}); - CHECK_EQ(y.shape(), std::vector{8, 4}); + CHECK_EQ(y.shape(), Shape{8, 4}); } // Check the types of real ffts { x = zeros({5, 5}, float32); auto y = fft::rfft2(x); - CHECK_EQ(y.shape(), std::vector{5, 3}); + CHECK_EQ(y.shape(), Shape{5, 3}); CHECK_EQ(y.dtype(), complex64); y = fft::rfftn(x); - CHECK_EQ(y.shape(), std::vector{5, 3}); + CHECK_EQ(y.shape(), Shape{5, 3}); CHECK_EQ(y.dtype(), complex64); x = zeros({5, 5}, complex64); y = fft::irfft2(x); - CHECK_EQ(y.shape(), std::vector{5, 8}); + CHECK_EQ(y.shape(), Shape{5, 8}); CHECK_EQ(y.dtype(), float32); y = fft::irfftn(x); - CHECK_EQ(y.shape(), std::vector{5, 8}); + CHECK_EQ(y.shape(), Shape{5, 8}); CHECK_EQ(y.dtype(), float32); } } @@ -199,25 +199,25 @@ TEST_CASE("test fft with provided shape") { auto x = ones({5, 5}); auto y = fft::fft(x, 7, 0); - CHECK_EQ(y.shape(), std::vector{7, 5}); + CHECK_EQ(y.shape(), Shape{7, 5}); y = fft::fft(x, 3, 0); - CHECK_EQ(y.shape(), std::vector{3, 5}); + CHECK_EQ(y.shape(), Shape{3, 5}); y = fft::fft(x, 7, 1); - CHECK_EQ(y.shape(), std::vector{5, 7}); + CHECK_EQ(y.shape(), Shape{5, 7}); y = fft::fft(x, 3, 1); - CHECK_EQ(y.shape(), std::vector{5, 3}); + CHECK_EQ(y.shape(), Shape{5, 3}); y = fft::rfft(x, 7, 0); - CHECK_EQ(y.shape(), std::vector{4, 5}); + CHECK_EQ(y.shape(), Shape{4, 5}); y = fft::rfft(x, 3, 0); - CHECK_EQ(y.shape(), std::vector{2, 5}); + CHECK_EQ(y.shape(), Shape{2, 5}); y = fft::rfft(x, 3, 1); - CHECK_EQ(y.shape(), std::vector{5, 2}); + CHECK_EQ(y.shape(), Shape{5, 2}); } TEST_CASE("test fft vmap") { @@ -288,23 +288,23 @@ TEST_CASE("test fft grads") { astype(zeros({5, 5}), complex64), astype(zeros({5, 5}), complex64)) .second; - CHECK_EQ(vjp_out.shape(), std::vector{5, 5}); + CHECK_EQ(vjp_out.shape(), Shape{5, 5}); vjp_out = vjp([](array x) { return fft::ifftn(x); }, astype(zeros({5, 5}), complex64), astype(zeros({5, 5}), complex64)) .second; - CHECK_EQ(vjp_out.shape(), std::vector{5, 5}); + CHECK_EQ(vjp_out.shape(), Shape{5, 5}); vjp_out = vjp([](array x) { return fft::rfftn(x); }, zeros({5, 9}), astype(zeros({5, 5}), complex64)) .second; - CHECK_EQ(vjp_out.shape(), std::vector{5, 9}); + CHECK_EQ(vjp_out.shape(), Shape{5, 9}); vjp_out = vjp([](array x) { return fft::irfftn(x); }, astype(zeros({5, 5}), complex64), zeros({5, 8})) .second; - CHECK_EQ(vjp_out.shape(), std::vector{5, 5}); + CHECK_EQ(vjp_out.shape(), Shape{5, 5}); } diff --git a/tests/linalg_tests.cpp b/tests/linalg_tests.cpp index f0b34cc01..c5a1b8808 100644 
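A note on the rfft shape checks a few hunks up: they all follow one rule, namely that only the last transformed axis changes size, from n to n / 2 + 1 (the non-redundant half of a real transform). A hedged helper that reproduces those expected shapes; the alias and function name are made up for illustration.

#include <cstdint>
#include <vector>

using Shape = std::vector<int32_t>; // assumed alias, as in the sketch above

// Real-input FFTs keep every dimension except the last transformed axis,
// which shrinks from n to n / 2 + 1.
Shape rfft_out_shape(Shape shape, int n, int last_axis) {
  shape[last_axis] = n / 2 + 1;
  return shape;
}
// Reproduces the expectations in the fft tests, e.g. rfft_out_shape({5, 5}, 7, 0)
// gives {4, 5} and rfft_out_shape({5, 4}, 4, 1) gives {5, 3}.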
--- a/tests/linalg_tests.cpp +++ b/tests/linalg_tests.cpp @@ -129,18 +129,10 @@ TEST_CASE("[mlx.core.linalg.norm] double ord") { CHECK_EQ( norm(x, -1.0, std::vector{1, 0}).item(), doctest::Approx(3.0)); - CHECK_EQ( - norm(x, 1.0, std::vector{0, 1}, true).shape(), - std::vector{1, 1}); - CHECK_EQ( - norm(x, 1.0, std::vector{1, 0}, true).shape(), - std::vector{1, 1}); - CHECK_EQ( - norm(x, -1.0, std::vector{0, 1}, true).shape(), - std::vector{1, 1}); - CHECK_EQ( - norm(x, -1.0, std::vector{1, 0}, true).shape(), - std::vector{1, 1}); + CHECK_EQ(norm(x, 1.0, std::vector{0, 1}, true).shape(), Shape{1, 1}); + CHECK_EQ(norm(x, 1.0, std::vector{1, 0}, true).shape(), Shape{1, 1}); + CHECK_EQ(norm(x, -1.0, std::vector{0, 1}, true).shape(), Shape{1, 1}); + CHECK_EQ(norm(x, -1.0, std::vector{1, 0}, true).shape(), Shape{1, 1}); CHECK_EQ( norm(x, -1.0, std::vector{-2, -1}, false).item(), @@ -286,9 +278,9 @@ TEST_CASE("test SVD factorization") { const auto& S = outs[1]; const auto& Vt = outs[2]; - CHECK_EQ(U.shape(), std::vector{5, 5}); - CHECK_EQ(S.shape(), std::vector{4}); - CHECK_EQ(Vt.shape(), std::vector{4, 4}); + CHECK_EQ(U.shape(), Shape{5, 5}); + CHECK_EQ(S.shape(), Shape{4}); + CHECK_EQ(Vt.shape(), Shape{4, 4}); const auto U_slice = slice(U, {0, 0}, {U.shape(0), S.shape(0)}); diff --git a/tests/load_tests.cpp b/tests/load_tests.cpp index b972f08d7..1531ce060 100644 --- a/tests/load_tests.cpp +++ b/tests/load_tests.cpp @@ -32,11 +32,11 @@ TEST_CASE("test save_safetensors") { CHECK_EQ(dict.count("test2"), 1); array test = dict.at("test"); CHECK_EQ(test.dtype(), float32); - CHECK_EQ(test.shape(), std::vector({4})); + CHECK_EQ(test.shape(), Shape{4}); CHECK(array_equal(test, array({1.0, 2.0, 3.0, 4.0})).item()); array test2 = dict.at("test2"); CHECK_EQ(test2.dtype(), float32); - CHECK_EQ(test2.shape(), std::vector({2, 2})); + CHECK_EQ(test2.shape(), Shape{2, 2}); CHECK(array_equal(test2, ones({2, 2})).item()); } diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index 3b9a11e6f..545f5e24c 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -15,13 +15,13 @@ using namespace mlx::core; TEST_CASE("test copy") { array x(1.0); auto y = copy(x); - CHECK_EQ(y.shape(), std::vector{}); + CHECK_EQ(y.shape(), Shape{}); CHECK_NE(y.id(), x.id()); CHECK_EQ(y.item(), 1.0f); x = array({1, 2}, {2, 1}); y = copy(x); - CHECK_EQ(y.shape(), std::vector{2, 1}); + CHECK_EQ(y.shape(), Shape{2, 1}); CHECK_EQ(y.dtype(), int32); CHECK_NE(y.id(), x.id()); CHECK(array_equal(y, x).item()); @@ -29,37 +29,37 @@ TEST_CASE("test copy") { TEST_CASE("test reshape") { array x(1.0); - CHECK_EQ(reshape(x, {}).shape(), std::vector{}); + CHECK_EQ(reshape(x, {}).shape(), Shape{}); CHECK_THROWS_AS(reshape(x, {2}), std::invalid_argument); auto y = reshape(x, {1, 1, 1}); - CHECK_EQ(y.shape(), std::vector{1, 1, 1}); + CHECK_EQ(y.shape(), Shape{1, 1, 1}); y = reshape(x, {-1, 1, 1}); - CHECK_EQ(y.shape(), std::vector{1, 1, 1}); + CHECK_EQ(y.shape(), Shape{1, 1, 1}); y = reshape(x, {1, 1, -1}); - CHECK_EQ(y.shape(), std::vector{1, 1, 1}); + CHECK_EQ(y.shape(), Shape{1, 1, 1}); CHECK_THROWS_AS(reshape(x, {1, -1, -1}), std::invalid_argument); CHECK_THROWS_AS(reshape(x, {2, -1}), std::invalid_argument); x = zeros({2, 2, 2}); y = reshape(x, {8}); - CHECK_EQ(y.shape(), std::vector{8}); + CHECK_EQ(y.shape(), Shape{8}); CHECK_THROWS_AS(reshape(x, {7}), std::invalid_argument); y = reshape(x, {-1}); - CHECK_EQ(y.shape(), std::vector{8}); + CHECK_EQ(y.shape(), Shape{8}); y = reshape(x, {-1, 2}); - CHECK_EQ(y.shape(), std::vector{4, 2}); + 
CHECK_EQ(y.shape(), Shape{4, 2}); CHECK_THROWS_AS(reshape(x, {-1, 7}), std::invalid_argument); // Works with empty array x = array({}); y = reshape(x, {0, 0, 0}); - CHECK_EQ(y.shape(), std::vector{0, 0, 0}); + CHECK_EQ(y.shape(), Shape{0, 0, 0}); y.eval(); CHECK_EQ(y.size(), 0); CHECK_THROWS_AS(reshape(x, {}), std::invalid_argument); CHECK_THROWS_AS(reshape(x, {1}), std::invalid_argument); y = reshape(x, {1, 5, 0}); - CHECK_EQ(y.shape(), std::vector{1, 5, 0}); + CHECK_EQ(y.shape(), Shape{1, 5, 0}); // Check that reshaping a transposed array doesn't result in a copy x = reshape(arange(64), {2, 4, 8}); @@ -138,15 +138,15 @@ TEST_CASE("test reshape") { TEST_CASE("test flatten") { array x = zeros({2, 3, 4}); - CHECK_EQ(flatten(x).shape(), std::vector({2 * 3 * 4})); + CHECK_EQ(flatten(x).shape(), Shape({2 * 3 * 4})); - CHECK_EQ(flatten(x, 1, 1).shape(), std::vector({2, 3, 4})); - CHECK_EQ(flatten(x, 1, 2).shape(), std::vector({2, 3 * 4})); - CHECK_EQ(flatten(x, 1, 3).shape(), std::vector({2, 3 * 4})); - CHECK_EQ(flatten(x, 1, -1).shape(), std::vector({2, 3 * 4})); - CHECK_EQ(flatten(x, -2, -1).shape(), std::vector({2, 3 * 4})); - CHECK_EQ(flatten(x, -3, -1).shape(), std::vector({2 * 3 * 4})); - CHECK_EQ(flatten(x, -4, -1).shape(), std::vector({2 * 3 * 4})); + CHECK_EQ(flatten(x, 1, 1).shape(), Shape({2, 3, 4})); + CHECK_EQ(flatten(x, 1, 2).shape(), Shape({2, 3 * 4})); + CHECK_EQ(flatten(x, 1, 3).shape(), Shape({2, 3 * 4})); + CHECK_EQ(flatten(x, 1, -1).shape(), Shape({2, 3 * 4})); + CHECK_EQ(flatten(x, -2, -1).shape(), Shape({2, 3 * 4})); + CHECK_EQ(flatten(x, -3, -1).shape(), Shape({2 * 3 * 4})); + CHECK_EQ(flatten(x, -4, -1).shape(), Shape({2 * 3 * 4})); // Check start > end throws CHECK_THROWS(flatten(x, 2, 1)); @@ -159,17 +159,17 @@ TEST_CASE("test flatten") { // Check scalar flattens to 1D x = array(1); - CHECK_EQ(flatten(x, -3, -1).shape(), std::vector({1})); - CHECK_EQ(flatten(x, 0, 0).shape(), std::vector({1})); + CHECK_EQ(flatten(x, -3, -1).shape(), Shape({1})); + CHECK_EQ(flatten(x, 0, 0).shape(), Shape({1})); } TEST_CASE("test squeeze and expand") { array x = zeros({2, 1, 2, 1, 2, 1}); - CHECK_EQ(squeeze(x).shape(), std::vector{2, 2, 2}); - CHECK_EQ(squeeze(x, {1, 3, 5}).shape(), std::vector{2, 2, 2}); - CHECK_EQ(squeeze(x, {-1, -3, -5}).shape(), std::vector{2, 2, 2}); - CHECK_EQ(squeeze(x, 1).shape(), std::vector{2, 2, 1, 2, 1}); - CHECK_EQ(squeeze(x, -1).shape(), std::vector{2, 1, 2, 1, 2}); + CHECK_EQ(squeeze(x).shape(), Shape{2, 2, 2}); + CHECK_EQ(squeeze(x, {1, 3, 5}).shape(), Shape{2, 2, 2}); + CHECK_EQ(squeeze(x, {-1, -3, -5}).shape(), Shape{2, 2, 2}); + CHECK_EQ(squeeze(x, 1).shape(), Shape{2, 2, 1, 2, 1}); + CHECK_EQ(squeeze(x, -1).shape(), Shape{2, 1, 2, 1, 2}); CHECK_THROWS(squeeze(x, 0)); CHECK_THROWS(squeeze(x, 2)); @@ -177,13 +177,13 @@ TEST_CASE("test squeeze and expand") { CHECK_THROWS(squeeze(x, {1, 3, -3})); x = zeros({2, 2}); - CHECK_EQ(expand_dims(x, 0).shape(), std::vector{1, 2, 2}); - CHECK_EQ(expand_dims(x, -1).shape(), std::vector{2, 2, 1}); - CHECK_EQ(expand_dims(x, 1).shape(), std::vector{2, 1, 2}); - CHECK_EQ(expand_dims(x, {0, 1, 2}).shape(), std::vector{1, 1, 1, 2, 2}); + CHECK_EQ(expand_dims(x, 0).shape(), Shape{1, 2, 2}); + CHECK_EQ(expand_dims(x, -1).shape(), Shape{2, 2, 1}); + CHECK_EQ(expand_dims(x, 1).shape(), Shape{2, 1, 2}); + CHECK_EQ(expand_dims(x, {0, 1, 2}).shape(), Shape{1, 1, 1, 2, 2}); CHECK_EQ( expand_dims(x, {0, 1, 2, 5, 6, 7}).shape(), - std::vector{1, 1, 1, 2, 2, 1, 1, 1}); + Shape{1, 1, 1, 2, 2, 1, 1, 1}); CHECK_THROWS(expand_dims(x, 
3)); CHECK_THROWS(expand_dims(x, -4)); @@ -210,7 +210,7 @@ TEST_CASE("test slice") { out = slice(x, {1}, {0}); eval(out); - CHECK_EQ(out.shape(), std::vector{0}); + CHECK_EQ(out.shape(), Shape{0}); out = slice(x, {0}, {1}, {1}); CHECK_EQ(out.item(), 3); @@ -353,7 +353,7 @@ TEST_CASE("test split") { out = split(x, 3, -1); CHECK_EQ(out.size(), 3); for (auto i = 0; i < 3; ++i) { - CHECK_EQ(out[i].shape(), std::vector{1}); + CHECK_EQ(out[i].shape(), Shape{1}); CHECK_EQ(out[i].dtype(), int32); CHECK_EQ(out[i].item(), i); } @@ -370,13 +370,13 @@ TEST_CASE("test split") { x = zeros({8, 12}); out = split(x, 2); CHECK_EQ(out.size(), 2); - CHECK_EQ(out[0].shape(), std::vector{4, 12}); - CHECK_EQ(out[1].shape(), std::vector{4, 12}); + CHECK_EQ(out[0].shape(), Shape{4, 12}); + CHECK_EQ(out[1].shape(), Shape{4, 12}); out = split(x, 3, 1); CHECK_EQ(out.size(), 3); - CHECK_EQ(out[0].shape(), std::vector{8, 4}); - CHECK_EQ(out[1].shape(), std::vector{8, 4}); - CHECK_EQ(out[2].shape(), std::vector{8, 4}); + CHECK_EQ(out[0].shape(), Shape{8, 4}); + CHECK_EQ(out[1].shape(), Shape{8, 4}); + CHECK_EQ(out[2].shape(), Shape{8, 4}); out = split(x, std::vector{}); CHECK_EQ(out.size(), 1); @@ -384,25 +384,25 @@ TEST_CASE("test split") { out = split(x, {3, 7}); CHECK_EQ(out.size(), 3); - CHECK_EQ(out[0].shape(), std::vector{3, 12}); - CHECK_EQ(out[1].shape(), std::vector{4, 12}); - CHECK_EQ(out[2].shape(), std::vector{1, 12}); + CHECK_EQ(out[0].shape(), Shape{3, 12}); + CHECK_EQ(out[1].shape(), Shape{4, 12}); + CHECK_EQ(out[2].shape(), Shape{1, 12}); out = split(x, std::vector{20}); CHECK_EQ(out.size(), 2); - CHECK_EQ(out[0].shape(), std::vector{8, 12}); - CHECK_EQ(out[1].shape(), std::vector{0, 12}); + CHECK_EQ(out[0].shape(), Shape{8, 12}); + CHECK_EQ(out[1].shape(), Shape{0, 12}); // Negative indices out = split(x, std::vector{-5}); - CHECK_EQ(out[0].shape(), std::vector{3, 12}); - CHECK_EQ(out[1].shape(), std::vector{5, 12}); + CHECK_EQ(out[0].shape(), Shape{3, 12}); + CHECK_EQ(out[1].shape(), Shape{5, 12}); // Different axis out = split(x, std::vector{2, 8}, 1); - CHECK_EQ(out[0].shape(), std::vector{8, 2}); - CHECK_EQ(out[1].shape(), std::vector{8, 6}); - CHECK_EQ(out[2].shape(), std::vector{8, 4}); + CHECK_EQ(out[0].shape(), Shape{8, 2}); + CHECK_EQ(out[1].shape(), Shape{8, 6}); + CHECK_EQ(out[2].shape(), Shape{8, 4}); // Out of order indices x = arange(5); @@ -420,18 +420,18 @@ TEST_CASE("test swap and move axes") { a = zeros({2}); CHECK_THROWS(swapaxes(a, 0, 1)); - CHECK_EQ(swapaxes(a, 0, 0).shape(), std::vector{2}); - CHECK_EQ(swapaxes(a, -1, -1).shape(), std::vector{2}); + CHECK_EQ(swapaxes(a, 0, 0).shape(), Shape{2}); + CHECK_EQ(swapaxes(a, -1, -1).shape(), Shape{2}); a = zeros({2, 3, 4}); CHECK_THROWS(swapaxes(a, 0, -4)); CHECK_THROWS(swapaxes(a, 0, 3)); CHECK_THROWS(swapaxes(a, 3, 0)); CHECK_THROWS(swapaxes(a, -4, 0)); - CHECK_EQ(swapaxes(a, 0, 2).shape(), std::vector{4, 3, 2}); - CHECK_EQ(swapaxes(a, 0, 1).shape(), std::vector{3, 2, 4}); - CHECK_EQ(swapaxes(a, 0, -1).shape(), std::vector{4, 3, 2}); - CHECK_EQ(swapaxes(a, -2, 2).shape(), std::vector{2, 4, 3}); + CHECK_EQ(swapaxes(a, 0, 2).shape(), Shape{4, 3, 2}); + CHECK_EQ(swapaxes(a, 0, 1).shape(), Shape{3, 2, 4}); + CHECK_EQ(swapaxes(a, 0, -1).shape(), Shape{4, 3, 2}); + CHECK_EQ(swapaxes(a, -2, 2).shape(), Shape{2, 4, 3}); // Test moveaxis a = array(0.0); @@ -439,36 +439,36 @@ TEST_CASE("test swap and move axes") { a = zeros({2}); CHECK_THROWS(moveaxis(a, 0, 1)); - CHECK_EQ(moveaxis(a, 0, 0).shape(), std::vector{2}); - CHECK_EQ(moveaxis(a, -1, 
-1).shape(), std::vector{2}); + CHECK_EQ(moveaxis(a, 0, 0).shape(), Shape{2}); + CHECK_EQ(moveaxis(a, -1, -1).shape(), Shape{2}); a = zeros({2, 3, 4}); CHECK_THROWS(moveaxis(a, 0, -4)); CHECK_THROWS(moveaxis(a, 0, 3)); CHECK_THROWS(moveaxis(a, 3, 0)); CHECK_THROWS(moveaxis(a, -4, 0)); - CHECK_EQ(moveaxis(a, 0, 2).shape(), std::vector{3, 4, 2}); - CHECK_EQ(moveaxis(a, 0, 1).shape(), std::vector{3, 2, 4}); - CHECK_EQ(moveaxis(a, 0, -1).shape(), std::vector{3, 4, 2}); - CHECK_EQ(moveaxis(a, -2, 2).shape(), std::vector{2, 4, 3}); + CHECK_EQ(moveaxis(a, 0, 2).shape(), Shape{3, 4, 2}); + CHECK_EQ(moveaxis(a, 0, 1).shape(), Shape{3, 2, 4}); + CHECK_EQ(moveaxis(a, 0, -1).shape(), Shape{3, 4, 2}); + CHECK_EQ(moveaxis(a, -2, 2).shape(), Shape{2, 4, 3}); } TEST_CASE("test transpose") { array x(1); auto y = transpose(x); - CHECK_EQ(y.shape(), std::vector{}); + CHECK_EQ(y.shape(), Shape{}); CHECK_EQ(y.item(), 1); CHECK_THROWS_AS(transpose(x, {0}), std::invalid_argument); CHECK_THROWS_AS(transpose(x, {1}), std::invalid_argument); x = array({1}, {1}); y = transpose(x); - CHECK_EQ(y.shape(), std::vector{1}); + CHECK_EQ(y.shape(), Shape{1}); CHECK_EQ(y.item(), 1); // Negative indices y = transpose(x, {-1}); - CHECK_EQ(y.shape(), std::vector{1}); + CHECK_EQ(y.shape(), Shape{1}); CHECK_EQ(y.item(), 1); CHECK_THROWS_AS(transpose(x, {1}), std::invalid_argument); @@ -477,24 +477,24 @@ TEST_CASE("test transpose") { // Works with empty array x = array({}); y = transpose(x); - CHECK_EQ(y.shape(), std::vector{0}); + CHECK_EQ(y.shape(), Shape{0}); y.eval(); CHECK_EQ(y.size(), 0); x = array({1, 2, 3, 4, 5, 6}, {2, 3}); y = transpose(x); - CHECK_EQ(y.shape(), std::vector{3, 2}); + CHECK_EQ(y.shape(), Shape{3, 2}); y = transpose(x, {-1, 0}); - CHECK_EQ(y.shape(), std::vector{3, 2}); + CHECK_EQ(y.shape(), Shape{3, 2}); y = transpose(x, {-1, -2}); - CHECK_EQ(y.shape(), std::vector{3, 2}); + CHECK_EQ(y.shape(), Shape{3, 2}); y.eval(); CHECK(array_equal(y, array({1, 4, 2, 5, 3, 6}, {3, 2})).item()); y = transpose(x, {0, 1}); - CHECK_EQ(y.shape(), std::vector{2, 3}); + CHECK_EQ(y.shape(), Shape{2, 3}); CHECK(array_equal(y, x).item()); y = transpose(x, {0, -1}); - CHECK_EQ(y.shape(), std::vector{2, 3}); + CHECK_EQ(y.shape(), Shape{2, 3}); CHECK(array_equal(y, x).item()); CHECK_THROWS_AS(transpose(x, {}), std::invalid_argument); @@ -505,19 +505,19 @@ TEST_CASE("test transpose") { x = array({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {2, 3, 2}); y = transpose(x); - CHECK_EQ(y.shape(), std::vector{2, 3, 2}); + CHECK_EQ(y.shape(), Shape{2, 3, 2}); auto expected = array({1, 7, 3, 9, 5, 11, 2, 8, 4, 10, 6, 12}, {2, 3, 2}); CHECK(array_equal(y, expected).item()); y = transpose(x, {0, 1, 2}); - CHECK_EQ(y.shape(), std::vector{2, 3, 2}); + CHECK_EQ(y.shape(), Shape{2, 3, 2}); CHECK(array_equal(y, x).item()); y = transpose(x, {1, 0, 2}); - CHECK_EQ(y.shape(), std::vector{3, 2, 2}); + CHECK_EQ(y.shape(), Shape{3, 2, 2}); expected = array({1, 2, 7, 8, 3, 4, 9, 10, 5, 6, 11, 12}, {3, 2, 2}); CHECK(array_equal(y, expected).item()); y = transpose(x, {0, 2, 1}); - CHECK_EQ(y.shape(), std::vector{2, 2, 3}); + CHECK_EQ(y.shape(), Shape{2, 2, 3}); expected = array({1, 3, 5, 2, 4, 6, 7, 9, 11, 8, 10, 12}, {2, 2, 3}); CHECK(array_equal(y, expected).item()); @@ -542,7 +542,7 @@ TEST_CASE("test comparison ops") { array y({}); auto z = x == y; CHECK_EQ(z.dtype(), bool_); - CHECK_EQ(z.shape(), std::vector{0}); + CHECK_EQ(z.shape(), Shape{0}); } // Basic cases @@ -631,7 +631,7 @@ TEST_CASE("test comparison ops") { auto y = zeros({2, 1}); auto z = 
equal(x, y); CHECK_EQ(z.dtype(), bool_); - CHECK_EQ(z.shape(), std::vector{2, 2}); + CHECK_EQ(z.shape(), Shape{2, 2}); auto expected = array({true, true, true, true}, {2, 2}); CHECK(array_equal(z, expected).item()); @@ -639,7 +639,7 @@ TEST_CASE("test comparison ops") { y = array({1.0, 2.0}, {2, 1}); z = equal(x, y); CHECK_EQ(z.dtype(), bool_); - CHECK_EQ(z.shape(), std::vector{2, 2}); + CHECK_EQ(z.shape(), Shape{2, 2}); expected = array({true, false, false, true}, {2, 2}); CHECK(array_equal(z, expected).item()); @@ -769,15 +769,15 @@ TEST_CASE("test reduction ops") { CHECK_THROWS_AS(sum(x, 0), std::out_of_range); CHECK_THROWS_AS(sum(x, -1), std::out_of_range); out = sum(x, std::vector{}); - CHECK_EQ(out.shape(), std::vector{}); + CHECK_EQ(out.shape(), Shape{}); CHECK_EQ(out.size(), 1); x = array({}); out = sum(x); - CHECK_EQ(out.shape(), std::vector{}); + CHECK_EQ(out.shape(), Shape{}); CHECK_EQ(out.size(), 1); out = sum(x, true); - CHECK_EQ(out.shape(), std::vector{1}); + CHECK_EQ(out.shape(), Shape{1}); out = sum(x, std::vector{}); CHECK_EQ(out.shape(), x.shape()); @@ -788,7 +788,7 @@ TEST_CASE("test reduction ops") { CHECK_EQ(out.ndim(), 0); out = sum(x, -1, true); CHECK_EQ(out.ndim(), 1); - CHECK_EQ(out.shape(), std::vector{1}); + CHECK_EQ(out.shape(), Shape{1}); CHECK_THROWS_AS(sum(x, 1), std::out_of_range); CHECK_THROWS_AS(sum(x, -2), std::out_of_range); @@ -797,21 +797,21 @@ TEST_CASE("test reduction ops") { x = zeros({2, 3, 4}); out = sum(x, {0, 2}); - CHECK_EQ(out.shape(), std::vector{3}); + CHECK_EQ(out.shape(), Shape{3}); out = sum(x, std::vector{}); CHECK_EQ(out.shape(), x.shape()); out = sum(x, {0, -1}); - CHECK_EQ(out.shape(), std::vector{3}); + CHECK_EQ(out.shape(), Shape{3}); out = sum(x, {0, -1}, true); - CHECK_EQ(out.shape(), std::vector{1, 3, 1}); + CHECK_EQ(out.shape(), Shape{1, 3, 1}); out = sum(x, true); - CHECK_EQ(out.shape(), std::vector{1, 1, 1}); + CHECK_EQ(out.shape(), Shape{1, 1, 1}); out = sum(x); - CHECK_EQ(out.shape(), std::vector{}); + CHECK_EQ(out.shape(), Shape{}); CHECK_THROWS_AS(sum(x, 3), std::out_of_range); CHECK_THROWS_AS(sum(x, -4), std::out_of_range); @@ -986,7 +986,7 @@ TEST_CASE("test reduction ops") { std::vector nums = {0.0f, 1.0f, 2.0f, 3.0f}; x = array(nums.data(), {2, 2}); auto y = logsumexp(x, {0, 1}, true); - CHECK_EQ(y.shape(), std::vector{1, 1}); + CHECK_EQ(y.shape(), Shape{1, 1}); auto result = std::log( std::exp(nums[0]) + std::exp(nums[1]) + std::exp(nums[2]) + std::exp(nums[3])); @@ -1594,7 +1594,7 @@ TEST_CASE("test arithmetic binary ops") { x = array({1.0, 2.0, 3.0}, {1, 3}); y = array({1.0, 2.0, 3.0}, {1, 3}); z = add(x, y); - CHECK_EQ(z.shape(), std::vector{1, 3}); + CHECK_EQ(z.shape(), Shape{1, 3}); auto eq = array_equal(z, array({2.0, 4.0, 6.0}, {1, 3})); CHECK(eq.item()); @@ -1626,13 +1626,13 @@ TEST_CASE("test arithmetic binary ops") { x = array({1.0, 2.0}, {1, 2}); y = array({1.0, 2.0}, {2, 1}); z = add(x, y); - CHECK_EQ(z.shape(), std::vector{2, 2}); + CHECK_EQ(z.shape(), Shape{2, 2}); eq = array_equal(z, array({2.0, 3.0, 3.0, 4.0}, {2, 2})); CHECK(eq.item()); x = ones({3, 2, 1}); z = x + 2.0; - CHECK_EQ(z.shape(), std::vector{3, 2, 1}); + CHECK_EQ(z.shape(), Shape{3, 2, 1}); eq = array_equal(z, array({3.0, 3.0, 3.0, 3.0, 3.0, 3.0}, {3, 2, 1})); CHECK(eq.item()); @@ -1642,7 +1642,7 @@ TEST_CASE("test arithmetic binary ops") { z = x + y; z.eval(); CHECK_EQ(z.size(), 0); - CHECK_EQ(z.shape(), std::vector{0}); + CHECK_EQ(z.shape(), Shape{0}); // Check subtraction x = array({3, 2, 1}); @@ -1725,46 +1725,46 @@ TEST_CASE("test 
arithmetic binary ops") { TEST_CASE("test broadcast") { auto s = broadcast_shapes({1}, {1, 2}); - CHECK_EQ(s, std::vector{1, 2}); + CHECK_EQ(s, Shape{1, 2}); s = broadcast_shapes({1, 2}, {1}); - CHECK_EQ(s, std::vector{1, 2}); + CHECK_EQ(s, Shape{1, 2}); s = broadcast_shapes({2, 2}, {}); - CHECK_EQ(s, std::vector{2, 2}); + CHECK_EQ(s, Shape{2, 2}); s = broadcast_shapes({}, {1, 1}); - CHECK_EQ(s, std::vector{1, 1}); + CHECK_EQ(s, Shape{1, 1}); s = broadcast_shapes({1, 2, 1}, {2}); - CHECK_EQ(s, std::vector{1, 2, 2}); + CHECK_EQ(s, Shape{1, 2, 2}); s = broadcast_shapes({2}, {1, 2, 1}); - CHECK_EQ(s, std::vector{1, 2, 2}); + CHECK_EQ(s, Shape{1, 2, 2}); s = broadcast_shapes({2, 2, 2}, {1, 2, 1}); - CHECK_EQ(s, std::vector{2, 2, 2}); + CHECK_EQ(s, Shape{2, 2, 2}); s = broadcast_shapes({2, 2, 2, 1}, {1, 2, 1}); - CHECK_EQ(s, std::vector{2, 2, 2, 1}); + CHECK_EQ(s, Shape{2, 2, 2, 1}); s = broadcast_shapes({0}, {0, 0}); - CHECK_EQ(s, std::vector{0, 0}); + CHECK_EQ(s, Shape{0, 0}); - CHECK_EQ(broadcast_shapes({}, {0}), std::vector{0}); + CHECK_EQ(broadcast_shapes({}, {0}), Shape{0}); s = broadcast_shapes({5, 0}, {0, 5, 0}); - CHECK_EQ(s, std::vector{0, 5, 0}); - - CHECK_EQ(broadcast_shapes({}, {0}), std::vector{0}); - CHECK_EQ(broadcast_shapes({1}, {0}), std::vector{0}); - CHECK_EQ(broadcast_shapes({1}, {0}), std::vector{0}); - CHECK_EQ(broadcast_shapes({1}, {0, 0}), std::vector{0, 0}); - CHECK_EQ(broadcast_shapes({1, 1}, {0}), std::vector{1, 0}); - CHECK_EQ(broadcast_shapes({1, 1}, {0, 0}), std::vector{0, 0}); - CHECK_EQ(broadcast_shapes({2, 1}, {1, 0}), std::vector{2, 0}); - CHECK_EQ(broadcast_shapes({2, 1}, {2, 0}), std::vector{2, 0}); - CHECK_EQ(broadcast_shapes({2, 1}, {1, 2, 0}), std::vector{1, 2, 0}); + CHECK_EQ(s, Shape{0, 5, 0}); + + CHECK_EQ(broadcast_shapes({}, {0}), Shape{0}); + CHECK_EQ(broadcast_shapes({1}, {0}), Shape{0}); + CHECK_EQ(broadcast_shapes({1}, {0}), Shape{0}); + CHECK_EQ(broadcast_shapes({1}, {0, 0}), Shape{0, 0}); + CHECK_EQ(broadcast_shapes({1, 1}, {0}), Shape{1, 0}); + CHECK_EQ(broadcast_shapes({1, 1}, {0, 0}), Shape{0, 0}); + CHECK_EQ(broadcast_shapes({2, 1}, {1, 0}), Shape{2, 0}); + CHECK_EQ(broadcast_shapes({2, 1}, {2, 0}), Shape{2, 0}); + CHECK_EQ(broadcast_shapes({2, 1}, {1, 2, 0}), Shape{1, 2, 0}); CHECK_THROWS_AS(broadcast_shapes({2}, {0}), std::invalid_argument); CHECK_THROWS_AS(broadcast_shapes({2, 1}, {0, 0}), std::invalid_argument); @@ -1778,19 +1778,19 @@ TEST_CASE("test broadcast") { CHECK_EQ(broadcast_to(x, {1, 1}).item(), 2.3f); x = broadcast_to(x, {5, 1}); - CHECK_EQ(x.shape(), std::vector{5, 1}); + CHECK_EQ(x.shape(), Shape{5, 1}); x.eval(); - CHECK_EQ(x.strides(), std::vector{0, 0}); + CHECK_EQ(x.strides(), Strides{0, 0}); CHECK_THROWS_AS(broadcast_to(x, {1, 5}), std::invalid_argument); x = broadcast_to(x, {5, 5}); - CHECK_EQ(x.shape(), std::vector{5, 5}); + CHECK_EQ(x.shape(), Shape{5, 5}); x = zeros({2, 1, 2}); x = broadcast_to(x, {4, 2, 1, 2}); - CHECK_EQ(x.shape(), std::vector{4, 2, 1, 2}); + CHECK_EQ(x.shape(), Shape{4, 2, 1, 2}); x.eval(); - CHECK_EQ(x.strides(), std::vector{0, 2, 0, 1}); + CHECK_EQ(x.strides(), Strides{0, 2, 0, 1}); // Broadcast on empty arrays works as expected x = array({}); @@ -1801,29 +1801,29 @@ TEST_CASE("test broadcast") { auto y = broadcast_to(x, {0}); eval(y); CHECK_EQ(y.size(), 0); - CHECK_EQ(y.shape(), std::vector{0}); + CHECK_EQ(y.shape(), Shape{0}); x = array({1, 2}, {2, 1}); y = broadcast_to(x, {2, 0}); eval(y); CHECK_EQ(y.size(), 0); - CHECK_EQ(y.shape(), std::vector{2, 0}); + CHECK_EQ(y.shape(), Shape{2, 0}); 
// Check repeat application works x = zeros({2}); x = broadcast_to(broadcast_to(x, {2, 2}), {2, 2}); - CHECK_EQ(x.shape(), std::vector{2, 2}); + CHECK_EQ(x.shape(), Shape{2, 2}); x.eval(); - CHECK_EQ(x.strides(), std::vector{0, 1}); + CHECK_EQ(x.strides(), Strides{0, 1}); x = broadcast_to(broadcast_to(x, {2, 2}), {2, 2, 2}); - CHECK_EQ(x.shape(), std::vector{2, 2, 2}); + CHECK_EQ(x.shape(), Shape{2, 2, 2}); x.eval(); - CHECK_EQ(x.strides(), std::vector{0, 0, 1}); + CHECK_EQ(x.strides(), Strides{0, 0, 1}); // Broadcast on transposed array works x = array({0, 1, 2, 3, 4, 5}, {2, 3}); x = broadcast_to(transpose(x), {2, 3, 2}); - CHECK_EQ(x.shape(), std::vector{2, 3, 2}); + CHECK_EQ(x.shape(), Shape{2, 3, 2}); y = broadcast_to(array({0, 3, 1, 4, 2, 5}, {3, 2}), {2, 3, 2}); CHECK(array_equal(x, y).item()); @@ -1867,16 +1867,16 @@ TEST_CASE("test gather") { auto x = arange(20); auto y = arange(10); auto out = gather(x, y, 0, {1}); - CHECK_EQ(out.shape(), std::vector{10, 1}); + CHECK_EQ(out.shape(), Shape{10, 1}); CHECK(array_equal(reshape(out, {-1}), y).item()); out = gather(x, array({15}, uint32), 0, {1}); - CHECK_EQ(out.shape(), std::vector{1, 1}); + CHECK_EQ(out.shape(), Shape{1, 1}); CHECK_EQ(out.item(), 15); // No index gather works out = gather(x, {}, std::vector{}, {10}); - CHECK_EQ(out.shape(), std::vector{10}); + CHECK_EQ(out.shape(), Shape{10}); CHECK(array_equal(out, arange(10)).item()); // Basic test of correctness with 2D input @@ -1884,13 +1884,13 @@ TEST_CASE("test gather") { x = reshape(x, {4, 32}); y = array({0, 1}, uint32); out = gather(x, y, 0, {1, 32}); - CHECK_EQ(out.shape(), std::vector{2, 1, 32}); + CHECK_EQ(out.shape(), Shape{2, 1, 32}); CHECK(array_equal(reshape(out, {64}), arange(64)).item()); x = reshape(x, {64, 2}); y = array({0}, uint32); out = gather(x, y, 0, {64, 1}); - CHECK_EQ(out.shape(), std::vector{1, 64, 1}); + CHECK_EQ(out.shape(), Shape{1, 64, 1}); CHECK(array_equal(out, reshape(arange(0, 128, 2), {1, 64, 1})).item()); // Basic test of correctness with 3D input @@ -1898,7 +1898,7 @@ TEST_CASE("test gather") { x = reshape(x, {8, 4, 8}); y = array({0}, uint32); out = gather(x, y, 0, {8, 1, 1}); - CHECK_EQ(out.shape(), std::vector{1, 8, 1, 1}); + CHECK_EQ(out.shape(), Shape{1, 8, 1, 1}); CHECK( array_equal(out, reshape(arange(0, 256, 32), {1, 8, 1, 1})).item()); @@ -1913,10 +1913,10 @@ TEST_CASE("test take") { // Empty takes auto empty = astype(array({}), int32); auto z = take(array({1}), empty); - CHECK_EQ(z.shape(), std::vector{0}); + CHECK_EQ(z.shape(), Shape{0}); empty = reshape(empty, {1, 0, 1}); z = take(array({1}), empty); - CHECK_EQ(z.shape(), std::vector{1, 0, 1}); + CHECK_EQ(z.shape(), Shape{1, 0, 1}); CHECK_THROWS(take(array({}), array(1))); @@ -1926,7 +1926,7 @@ TEST_CASE("test take") { // Take a single row auto x = reshape(arange(256), {8, 4, 8}); z = take(x, array({0}, uint32), 0); - CHECK_EQ(z.shape(), std::vector{1, 4, 8}); + CHECK_EQ(z.shape(), Shape{1, 4, 8}); z = reshape(z, {32}); CHECK(array_equal(z, arange(32)).item()); @@ -2017,12 +2017,12 @@ TEST_CASE("test take along axis") { out = take_along_axis(a, reshape(array({1}), {1, 1}), 0); eval(out); // Make sure it runs - CHECK_EQ(out.shape(), std::vector{1, 0}); + CHECK_EQ(out.shape(), Shape{1, 0}); auto inds = reshape(astype(array({}), int32), {1, 0}); out = take_along_axis(a, inds, 0); eval(out); // Make sure it runs - CHECK_EQ(out.shape(), std::vector{1, 0}); + CHECK_EQ(out.shape(), Shape{1, 0}); a = array({1, 2, 3, 4}, {2, 2}); inds = array({0, 1}, {1, 2}); @@ -2084,7 +2084,7 @@ 
TEST_CASE("test put along axis") { auto inds = reshape(astype(array({}), int32), {1, 0}); out = take_along_axis(a, inds, 0); eval(out); // Make sure it runs - CHECK_EQ(out.shape(), std::vector{1, 0}); + CHECK_EQ(out.shape(), Shape{1, 0}); a = array({1, 2, 3, 4}, {2, 2}); inds = array({0, 1}, {1, 2}); @@ -2506,9 +2506,9 @@ TEST_CASE("test scan op") { TEST_CASE("test pad") { auto x = zeros({1, 2, 3}); - CHECK_EQ(pad(x, 1).shape(), std::vector{3, 4, 5}); - CHECK_EQ(pad(x, {0, 1}).shape(), std::vector{2, 3, 4}); - CHECK_EQ(pad(x, {{1, 1}, {1, 2}, {3, 1}}).shape(), std::vector{3, 5, 7}); + CHECK_EQ(pad(x, 1).shape(), Shape{3, 4, 5}); + CHECK_EQ(pad(x, {0, 1}).shape(), Shape{2, 3, 4}); + CHECK_EQ(pad(x, {{1, 1}, {1, 2}, {3, 1}}).shape(), Shape{3, 5, 7}); x = array({1.0f, 2.0f, 3.0f, 4.0f}, {2, 2}); auto padded_x = pad(x, 1); @@ -2647,20 +2647,20 @@ TEST_CASE("test where") { TEST_CASE("test stack") { auto x = array({}); - CHECK_EQ(stack({x}, 0).shape(), std::vector{1, 0}); - CHECK_EQ(stack({x}, 1).shape(), std::vector{0, 1}); + CHECK_EQ(stack({x}, 0).shape(), Shape{1, 0}); + CHECK_EQ(stack({x}, 1).shape(), Shape{0, 1}); x = array({1, 2, 3}, {3}); - CHECK_EQ(stack({x}, 0).shape(), std::vector{1, 3}); - CHECK_EQ(stack({x}, 1).shape(), std::vector{3, 1}); + CHECK_EQ(stack({x}, 0).shape(), Shape{1, 3}); + CHECK_EQ(stack({x}, 1).shape(), Shape{3, 1}); auto y = array({4, 5, 6}, {3}); auto z = std::vector{x, y}; - CHECK_EQ(stack(z).shape(), std::vector{2, 3}); - CHECK_EQ(stack(z, 0).shape(), std::vector{2, 3}); - CHECK_EQ(stack(z, 1).shape(), std::vector{3, 2}); - CHECK_EQ(stack(z, -1).shape(), std::vector{3, 2}); - CHECK_EQ(stack(z, -2).shape(), std::vector{2, 3}); + CHECK_EQ(stack(z).shape(), Shape{2, 3}); + CHECK_EQ(stack(z, 0).shape(), Shape{2, 3}); + CHECK_EQ(stack(z, 1).shape(), Shape{3, 2}); + CHECK_EQ(stack(z, -1).shape(), Shape{3, 2}); + CHECK_EQ(stack(z, -2).shape(), Shape{2, 3}); CHECK_THROWS_MESSAGE(stack({}, 0), "No arrays provided for stacking"); @@ -2676,20 +2676,20 @@ TEST_CASE("test stack") { TEST_CASE("test eye") { auto eye_3 = eye(3); - CHECK_EQ(eye_3.shape(), std::vector{3, 3}); + CHECK_EQ(eye_3.shape(), Shape{3, 3}); auto expected_eye_3 = array({1.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 1.0f}, {3, 3}); CHECK(array_equal(eye_3, expected_eye_3).item()); auto eye_3x2 = eye(3, 2); - CHECK_EQ(eye_3x2.shape(), std::vector{3, 2}); + CHECK_EQ(eye_3x2.shape(), Shape{3, 2}); auto expected_eye_3x2 = array({1.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f}, {3, 2}); CHECK(array_equal(eye_3x2, expected_eye_3x2).item()); } TEST_CASE("test tri") { auto _tri = tri(4, 4, 0, float32); - CHECK_EQ(_tri.shape(), std::vector{4, 4}); + CHECK_EQ(_tri.shape(), Shape{4, 4}); auto expected_tri = array( {1.0f, 0.0f, @@ -2712,8 +2712,8 @@ TEST_CASE("test tri") { } TEST_CASE("test tril") { - auto _tril = tril(full(std::vector{4, 4}, 2.0f, float32), 0); - CHECK_EQ(_tril.shape(), std::vector{4, 4}); + auto _tril = tril(full({4, 4}, 2.0f, float32), 0); + CHECK_EQ(_tril.shape(), Shape{4, 4}); auto expected_tri = array( {2.0f, 0.0f, @@ -2736,8 +2736,8 @@ TEST_CASE("test tril") { } TEST_CASE("test triu") { - auto _triu = triu(full(std::vector{4, 4}, 2.0f, float32), 0); - CHECK_EQ(_triu.shape(), std::vector{4, 4}); + auto _triu = triu(full({4, 4}, 2.0f, float32), 0); + CHECK_EQ(_triu.shape(), Shape{4, 4}); auto expected_tri = array( {2.0f, 2.0f, @@ -2761,7 +2761,7 @@ TEST_CASE("test triu") { TEST_CASE("test identity") { auto id_4 = identity(4); - CHECK_EQ(id_4.shape(), std::vector{4, 4}); + CHECK_EQ(id_4.shape(), Shape{4, 4}); 
auto expected_id_4 = array( {1.0f, 0.0f, @@ -2785,7 +2785,7 @@ TEST_CASE("test identity") { TEST_CASE("test eye with positive k offset") { auto eye_3_k1 = eye(3, 4, 1); - CHECK_EQ(eye_3_k1.shape(), std::vector{3, 4}); + CHECK_EQ(eye_3_k1.shape(), Shape{3, 4}); auto expected_eye_3_k1 = array( {0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f}, {3, 4}); @@ -2794,7 +2794,7 @@ TEST_CASE("test eye with positive k offset") { TEST_CASE("test eye with negative k offset") { auto eye_4_k_minus1 = eye(4, 3, -1); - CHECK_EQ(eye_4_k_minus1.shape(), std::vector{4, 3}); + CHECK_EQ(eye_4_k_minus1.shape(), Shape{4, 3}); auto expected_eye_4_k_minus1 = array( {0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 1.0f}, {4, 3}); @@ -2844,9 +2844,9 @@ TEST_CASE("test quantize dequantize") { for (int i = 2; i <= 8; i *= 2) { int el_per_int = 32 / i; auto [x_q, scales, biases] = quantize(x, 128, i); - CHECK_EQ(x_q.shape(), std::vector{128, 512 / el_per_int}); - CHECK_EQ(scales.shape(), std::vector{128, 4}); - CHECK_EQ(biases.shape(), std::vector{128, 4}); + CHECK_EQ(x_q.shape(), Shape{128, 512 / el_per_int}); + CHECK_EQ(scales.shape(), Shape{128, 4}); + CHECK_EQ(biases.shape(), Shape{128, 4}); auto x_hat = dequantize(x_q, scales, biases, 128, i); auto max_diff = max(abs(x - x_hat)).item(); @@ -3081,7 +3081,7 @@ TEST_CASE("test diagonal") { out = diagonal(x, -5, 0, 1); eval(out); - CHECK_EQ(out.shape(), std::vector{0}); + CHECK_EQ(out.shape(), Shape{0}); x = array({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, {3, 2, 2}); out = diagonal(x, 1, 0, 1); @@ -3337,17 +3337,17 @@ TEST_CASE("test atleast_1d") { auto x = array(1); auto out = atleast_1d(x); CHECK_EQ(out.ndim(), 1); - CHECK_EQ(out.shape(), std::vector{1}); + CHECK_EQ(out.shape(), Shape{1}); x = array({1, 2, 3}, {3}); out = atleast_1d(x); CHECK_EQ(out.ndim(), 1); - CHECK_EQ(out.shape(), std::vector{3}); + CHECK_EQ(out.shape(), Shape{3}); x = array({1, 2, 3}, {3, 1}); out = atleast_1d(x); CHECK_EQ(out.ndim(), 2); - CHECK_EQ(out.shape(), std::vector{3, 1}); + CHECK_EQ(out.shape(), Shape{3, 1}); } TEST_CASE("test atleast_1d vector") { @@ -3356,28 +3356,28 @@ TEST_CASE("test atleast_1d vector") { auto out = atleast_1d(x); CHECK_EQ(out.size(), 3); CHECK_EQ(out[0].ndim(), 1); - CHECK_EQ(out[0].shape(), std::vector{1}); + CHECK_EQ(out[0].shape(), Shape{1}); CHECK_EQ(out[1].ndim(), 1); - CHECK_EQ(out[1].shape(), std::vector{3}); + CHECK_EQ(out[1].shape(), Shape{3}); CHECK_EQ(out[2].ndim(), 2); - CHECK_EQ(out[2].shape(), std::vector{3, 1}); + CHECK_EQ(out[2].shape(), Shape{3, 1}); } TEST_CASE("test atleast_2d") { auto x = array(1); auto out = atleast_2d(x); CHECK_EQ(out.ndim(), 2); - CHECK_EQ(out.shape(), std::vector{1, 1}); + CHECK_EQ(out.shape(), Shape{1, 1}); x = array({1, 2, 3}, {3}); out = atleast_2d(x); CHECK_EQ(out.ndim(), 2); - CHECK_EQ(out.shape(), std::vector{1, 3}); + CHECK_EQ(out.shape(), Shape{1, 3}); x = array({1, 2, 3}, {3, 1}); out = atleast_2d(x); CHECK_EQ(out.ndim(), 2); - CHECK_EQ(out.shape(), std::vector{3, 1}); + CHECK_EQ(out.shape(), Shape{3, 1}); } TEST_CASE("test atleast_2d vector") { @@ -3386,28 +3386,28 @@ TEST_CASE("test atleast_2d vector") { auto out = atleast_2d(x); CHECK_EQ(out.size(), 3); CHECK_EQ(out[0].ndim(), 2); - CHECK_EQ(out[0].shape(), std::vector{1, 1}); + CHECK_EQ(out[0].shape(), Shape{1, 1}); CHECK_EQ(out[1].ndim(), 2); - CHECK_EQ(out[1].shape(), std::vector{1, 3}); + CHECK_EQ(out[1].shape(), Shape{1, 3}); CHECK_EQ(out[2].ndim(), 2); - CHECK_EQ(out[2].shape(), std::vector{3, 1}); + 
CHECK_EQ(out[2].shape(), Shape{3, 1}); } TEST_CASE("test atleast_3d") { auto x = array(1); auto out = atleast_3d(x); CHECK_EQ(out.ndim(), 3); - CHECK_EQ(out.shape(), std::vector{1, 1, 1}); + CHECK_EQ(out.shape(), Shape{1, 1, 1}); x = array({1, 2, 3}, {3}); out = atleast_3d(x); CHECK_EQ(out.ndim(), 3); - CHECK_EQ(out.shape(), std::vector{1, 3, 1}); + CHECK_EQ(out.shape(), Shape{1, 3, 1}); x = array({1, 2, 3}, {3, 1}); out = atleast_3d(x); CHECK_EQ(out.ndim(), 3); - CHECK_EQ(out.shape(), std::vector{3, 1, 1}); + CHECK_EQ(out.shape(), Shape{3, 1, 1}); } TEST_CASE("test atleast_3d vector") { @@ -3416,11 +3416,11 @@ TEST_CASE("test atleast_3d vector") { auto out = atleast_3d(x); CHECK_EQ(out.size(), 3); CHECK_EQ(out[0].ndim(), 3); - CHECK_EQ(out[0].shape(), std::vector{1, 1, 1}); + CHECK_EQ(out[0].shape(), Shape{1, 1, 1}); CHECK_EQ(out[1].ndim(), 3); - CHECK_EQ(out[1].shape(), std::vector{1, 3, 1}); + CHECK_EQ(out[1].shape(), Shape{1, 3, 1}); CHECK_EQ(out[2].ndim(), 3); - CHECK_EQ(out[2].shape(), std::vector{3, 1, 1}); + CHECK_EQ(out[2].shape(), Shape{3, 1, 1}); } TEST_CASE("test topk") { diff --git a/tests/random_tests.cpp b/tests/random_tests.cpp index a449fed83..7ab72c075 100644 --- a/tests/random_tests.cpp +++ b/tests/random_tests.cpp @@ -141,7 +141,7 @@ TEST_CASE("test random bits") { { auto key = array({0u, 0u, 1u, 1u}, {2, 2}); - auto shape = std::vector{3}; + auto shape = Shape{3}; auto fn = [&shape](array k) { return random::bits(shape, k); }; auto expected = array( @@ -264,7 +264,7 @@ TEST_CASE("test random uniform") { // Check broadcasting x = random::uniform(zeros({3, 1}), ones({1, 3}), {3, 3}); - CHECK_EQ(x.shape(), std::vector{3, 3}); + CHECK_EQ(x.shape(), Shape{3, 3}); CHECK_THROWS_AS( random::uniform(zeros({3, 3}), 1.0, {1, 3}), std::invalid_argument); CHECK_THROWS_AS( @@ -332,11 +332,11 @@ TEST_CASE("test random uniform") { return random::uniform(low, 1, {3}, float32, k); }; auto out = vmap(fun, -1)(key, zeros({2, 3})); - CHECK_EQ(out.shape(), std::vector{2, 3}); + CHECK_EQ(out.shape(), Shape{2, 3}); key = zeros({2, 2}, uint32); out = vmap(fun)(key, zeros({2, 3})); - CHECK_EQ(out.shape(), std::vector{2, 3}); + CHECK_EQ(out.shape(), Shape{2, 3}); } // Check bounds are respected @@ -425,7 +425,7 @@ TEST_CASE("test random multivariate_normal") { auto mean = zeros({3}); auto cov = eye(3); auto x = random::multivariate_normal(mean, cov, {1000}, float32); - CHECK_EQ(x.shape(), std::vector({1000, 3})); + CHECK_EQ(x.shape(), Shape{1000, 3}); CHECK_EQ(x.dtype(), float32); } @@ -435,7 +435,7 @@ TEST_CASE("test random multivariate_normal") { auto cov = array({1., -1, -.1, 1.}); cov = reshape(cov, {2, 2}); auto x = random::multivariate_normal(mean, cov, {1}, float32); - CHECK_EQ(x.shape(), std::vector({1, 2})); + CHECK_EQ(x.shape(), Shape{1, 2}); CHECK_EQ(x.dtype(), float32); } @@ -457,7 +457,7 @@ TEST_CASE("test random multivariate_normal") { auto mean = zeros({3}); auto cov = zeros({1, 2, 3, 3}); auto x = random::multivariate_normal(mean, cov, {1000, 2}, float32); - CHECK_EQ(x.shape(), std::vector({1000, 2, 3})); + CHECK_EQ(x.shape(), Shape{1000, 2, 3}); } { auto mean = zeros({3}); @@ -537,7 +537,7 @@ TEST_CASE("test random bernoulli") { // Return array with correct shape x = random::bernoulli(0.5, {3, 3}); - CHECK_EQ(x.shape(), std::vector({3, 3})); + CHECK_EQ(x.shape(), Shape{3, 3}); // Try with p = {} x = random::bernoulli(array({})); @@ -547,7 +547,7 @@ TEST_CASE("test random bernoulli") { auto p = array({0.1, 0.2, 0.3}); p = reshape(p, {1, 3}); x = random::bernoulli(p, {4, 3}); - 
CHECK_EQ(x.shape(), std::vector({4, 3})); + CHECK_EQ(x.shape(), Shape{4, 3}); CHECK_THROWS_AS(random::bernoulli(array({}), {3, 3}), std::invalid_argument); @@ -572,7 +572,7 @@ TEST_CASE("Test truncated normal") { // Requested shape x = random::truncated_normal(array(-2.0), array(2.0), {3, 4}); - CHECK_EQ(x.shape(), std::vector({3, 4})); + CHECK_EQ(x.shape(), Shape{3, 4}); // Empty array x = random::truncated_normal(array({}), array({})); @@ -584,7 +584,7 @@ TEST_CASE("Test truncated normal") { x = random::truncated_normal(lower, higher); // All in bounds - CHECK_EQ(x.shape(), std::vector({3, 2})); + CHECK_EQ(x.shape(), Shape{3, 2}); CHECK((all(x <= higher).item() && all(lower <= x).item())); // high < low => all equal to low @@ -615,17 +615,17 @@ TEST_CASE("test categorical") { CHECK_THROWS(categorical(logits, 1, std::vector{11})); CHECK_THROWS(categorical(logits, 1, {10, 1})); - CHECK_EQ(categorical(logits, -1).shape(), std::vector{10}); - CHECK_EQ(categorical(logits, 0).shape(), std::vector{20}); - CHECK_EQ(categorical(logits, 1).shape(), std::vector{10}); + CHECK_EQ(categorical(logits, -1).shape(), Shape{10}); + CHECK_EQ(categorical(logits, 0).shape(), Shape{20}); + CHECK_EQ(categorical(logits, 1).shape(), Shape{10}); auto out = categorical(logits); - CHECK_EQ(out.shape(), std::vector{10}); + CHECK_EQ(out.shape(), Shape{10}); CHECK_EQ(out.dtype(), uint32); CHECK(max(out).item() < 20); out = categorical(logits, 0, {5, 20}); - CHECK_EQ(out.shape(), std::vector{5, 20}); + CHECK_EQ(out.shape(), Shape{5, 20}); CHECK(max(out).item() < 10); float inf = std::numeric_limits::infinity(); @@ -636,9 +636,9 @@ TEST_CASE("test categorical") { CHECK_EQ(categorical(logits).item(), 1); logits = zeros({5, 4, 3}); - CHECK_EQ(categorical(logits, -1, 7).shape(), std::vector{5, 4, 7}); - CHECK_EQ(categorical(logits, -2, 7).shape(), std::vector{5, 3, 7}); - CHECK_EQ(categorical(logits, -3, 7).shape(), std::vector{4, 3, 7}); + CHECK_EQ(categorical(logits, -1, 7).shape(), Shape{5, 4, 7}); + CHECK_EQ(categorical(logits, -2, 7).shape(), Shape{5, 3, 7}); + CHECK_EQ(categorical(logits, -3, 7).shape(), Shape{4, 3, 7}); } TEST_CASE("test laplace") { diff --git a/tests/utils_tests.cpp b/tests/utils_tests.cpp index 3f2e3f814..666a6b749 100644 --- a/tests/utils_tests.cpp +++ b/tests/utils_tests.cpp @@ -37,46 +37,9 @@ TEST_CASE("test normalize axis") { {0, 3, 0}, {1, 3, 1}, {2, 3, 2}, {-1, 3, 2}, {-2, 3, 1}, {-3, 3, 0}}; for (const auto& tc : testCases) { - CHECK_EQ(normalize_axis(tc.axis, tc.ndim), tc.expected); + CHECK_EQ(normalize_axis_index(tc.axis, tc.ndim), tc.expected); } - CHECK_THROWS(normalize_axis(3, 3)); - CHECK_THROWS(normalize_axis(-4, 3)); -} - -TEST_CASE("test is same size and shape") { - struct TestCase { - std::vector a; - bool expected; - }; - - std::vector testCases = { - {{array({}), array({})}, true}, - {{array({1}), array({1})}, true}, - {{array({1, 2, 3}), array({1, 2, 4})}, true}, - {{array({1, 2, 3}), array({1, 2})}, false}}; - - for (const auto& tc : testCases) { - CHECK_EQ(is_same_shape(tc.a), tc.expected); - } -} - -TEST_CASE("test check shape dimension") { - int dim_min = std::numeric_limits::min(); - int dim_max = std::numeric_limits::max(); - CHECK_EQ(check_shape_dim(-4), -4); - CHECK_EQ(check_shape_dim(0), 0); - CHECK_EQ(check_shape_dim(12), 12); - CHECK_EQ(check_shape_dim(static_cast(dim_min)), dim_min); - CHECK_EQ(check_shape_dim(static_cast(dim_max)), dim_max); - CHECK_EQ(check_shape_dim(static_cast(0)), 0); - CHECK_EQ(check_shape_dim(static_cast(dim_max)), dim_max); - CHECK_THROWS_AS( 
- check_shape_dim(static_cast(dim_min) - 1), - std::invalid_argument); - CHECK_THROWS_AS( - check_shape_dim(static_cast(dim_max) + 1), - std::invalid_argument); - CHECK_THROWS_AS( - check_shape_dim(static_cast(dim_max) + 1), std::invalid_argument); + CHECK_THROWS(normalize_axis_index(3, 3)); + CHECK_THROWS(normalize_axis_index(-4, 3)); } diff --git a/tests/vmap_tests.cpp b/tests/vmap_tests.cpp index 1403d87ca..ba5b528ae 100644 --- a/tests/vmap_tests.cpp +++ b/tests/vmap_tests.cpp @@ -342,9 +342,9 @@ TEST_CASE("test vmap gather") { auto x = zeros({2, 2, 2, 2}); auto y = array({0, 1, 0, 0, 1, 0}, {2, 3}); auto out = vmap(fun, {0, -1})({x, y})[0]; - CHECK_EQ(out.shape(), std::vector{2, 2, 3, 2, 2}); + CHECK_EQ(out.shape(), Shape{2, 2, 3, 2, 2}); out = vmap(fun, {0, -1}, {3})({x, y})[0]; - CHECK_EQ(out.shape(), std::vector{2, 3, 2, 2, 2}); + CHECK_EQ(out.shape(), Shape{2, 3, 2, 2, 2}); } { @@ -358,7 +358,7 @@ TEST_CASE("test vmap gather") { auto x = zeros({2, 2, 2, 2}); auto y = array({0, 1, 0, 0, 1, 0}, {2, 3}); auto out = vmap(fun, {0, 0})({x, y})[0]; - CHECK_EQ(out.shape(), std::vector{2, 3, 2, 2}); + CHECK_EQ(out.shape(), Shape{2, 3, 2, 2}); } { @@ -373,7 +373,7 @@ TEST_CASE("test vmap gather") { auto y = array({0, 1, 0, 0, 1, 0}, {2, 3}); auto out = vmap(fun, {-1, 0})({x, y})[0]; - CHECK_EQ(out.shape(), std::vector{2, 3, 2, 2, 2}); + CHECK_EQ(out.shape(), Shape{2, 3, 2, 2, 2}); } { @@ -388,11 +388,11 @@ TEST_CASE("test vmap gather") { auto y = array({0, 1, 0, 0, 1, 0}, {2, 3}); auto z = array({0, 1, 0, 0, 1, 0}, {2, 3}); auto out = vmap(fun, {-1, 0, 0})({x, y, z})[0]; - CHECK_EQ(out.shape(), std::vector{2, 3, 2, 2}); + CHECK_EQ(out.shape(), Shape{2, 3, 2, 2}); z = array({0, 1, 0, 0, 1, 0}, {3, 2}); out = vmap(fun, {-1, 0, 1})({x, y, z})[0]; - CHECK_EQ(out.shape(), std::vector{2, 3, 2, 2}); + CHECK_EQ(out.shape(), Shape{2, 3, 2, 2}); } } @@ -483,9 +483,9 @@ TEST_CASE("test vmap SVD") { const auto& S = out.at(1); const auto& Vt = out.at(2); - CHECK_EQ(U.shape(), std::vector{a.shape(1), a.shape(0), a.shape(0)}); - CHECK_EQ(S.shape(), std::vector{a.shape(1), a.shape(2)}); - CHECK_EQ(Vt.shape(), std::vector{a.shape(1), a.shape(2), a.shape(2)}); + CHECK_EQ(U.shape(), Shape{a.shape(1), a.shape(0), a.shape(0)}); + CHECK_EQ(S.shape(), Shape{a.shape(1), a.shape(2)}); + CHECK_EQ(Vt.shape(), Shape{a.shape(1), a.shape(2), a.shape(2)}); } // vmap over the third axis. @@ -495,8 +495,8 @@ TEST_CASE("test vmap SVD") { const auto& S = out.at(1); const auto& Vt = out.at(2); - CHECK_EQ(U.shape(), std::vector{a.shape(2), a.shape(0), a.shape(0)}); - CHECK_EQ(S.shape(), std::vector{a.shape(2), a.shape(0)}); - CHECK_EQ(Vt.shape(), std::vector{a.shape(2), a.shape(1), a.shape(1)}); + CHECK_EQ(U.shape(), Shape{a.shape(2), a.shape(0), a.shape(0)}); + CHECK_EQ(S.shape(), Shape{a.shape(2), a.shape(0)}); + CHECK_EQ(Vt.shape(), Shape{a.shape(2), a.shape(1), a.shape(1)}); } }
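
Note on the Shape/Strides checks above: the broadcast tests assert zero strides for broadcast dimensions (for example, broadcast_to of a {2, 1, 2} array to {4, 2, 1, 2} is expected to yield Strides{0, 2, 0, 1}). The following is a minimal standalone sketch of that convention, assuming Shape and Strides are plain aliases over 32-bit dimensions and 64-bit strides; the aliases and the broadcast_strides helper are illustrative assumptions only and are not part of this patch.

// Minimal sketch (not from this patch): assume Shape holds 32-bit dims and
// Strides holds 64-bit strides, and compute the strides of a broadcast view.
#include <cstdint>
#include <iostream>
#include <vector>

using Shape = std::vector<int32_t>;   // assumed alias, for illustration
using Strides = std::vector<int64_t>; // assumed alias, for illustration

// Hypothetical helper: strides of src broadcast to dst, with stride 0 for
// dimensions that are added on the left or expanded from size 1.
Strides broadcast_strides(const Shape& src, const Shape& dst) {
  // Row-major (contiguous) strides of the source shape.
  Strides src_strides(src.size(), 1);
  for (int i = static_cast<int>(src.size()) - 2; i >= 0; --i) {
    src_strides[i] = src_strides[i + 1] * src[i + 1];
  }
  // Right-align src against dst; broadcast dimensions reuse the same
  // element, so their stride is 0.
  Strides out(dst.size(), 0);
  size_t offset = dst.size() - src.size();
  for (size_t i = 0; i < src.size(); ++i) {
    out[offset + i] = (src[i] == 1) ? 0 : src_strides[i];
  }
  return out;
}

int main() {
  // Mirrors the broadcast_to({2, 1, 2} -> {4, 2, 1, 2}) check above,
  // which expects strides {0, 2, 0, 1}.
  for (auto s : broadcast_strides({2, 1, 2}, {4, 2, 1, 2})) {
    std::cout << s << ' '; // prints: 0 2 0 1
  }
  std::cout << '\n';
  return 0;
}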